diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml new file mode 100644 index 0000000000000..183d52d5c1d44 --- /dev/null +++ b/.config/1espt/PipelineAutobaseliningConfig.yml @@ -0,0 +1,77 @@ +## DO NOT MODIFY THIS FILE MANUALLY. This is part of auto-baselining from 1ES Pipeline Templates. Go to [https://aka.ms/1espt-autobaselining] for more details. + +pipelines: + 1624: + retail: + source: + credscan: + lastModifiedDate: 2024-10-25 + policheck: + lastModifiedDate: 2024-10-25 + eslint: + lastModifiedDate: 2024-10-25 + psscriptanalyzer: + lastModifiedDate: 2024-10-25 + armory: + lastModifiedDate: 2024-10-25 + usedNonDefaultBranch: true + 1299: + retail: + source: + credscan: + lastModifiedDate: 2024-10-25 + eslint: + lastModifiedDate: 2024-10-25 + psscriptanalyzer: + lastModifiedDate: 2024-10-25 + armory: + lastModifiedDate: 2024-10-25 + policheck: + lastModifiedDate: 2024-10-29 + binary: + credscan: + lastModifiedDate: 2024-10-25 + binskim: + lastModifiedDate: 2024-10-25 + spotbugs: + lastModifiedDate: 2024-10-25 + 1625: + retail: + source: + credscan: + lastModifiedDate: 2024-11-05 + policheck: + lastModifiedDate: 2024-11-05 + eslint: + lastModifiedDate: 2024-11-05 + psscriptanalyzer: + lastModifiedDate: 2024-11-05 + armory: + lastModifiedDate: 2024-11-05 + binary: + credscan: + lastModifiedDate: 2024-11-13 + binskim: + lastModifiedDate: 2024-11-13 + spotbugs: + lastModifiedDate: 2024-11-13 + 1626: + retail: + source: + credscan: + lastModifiedDate: 2024-11-13 + policheck: + lastModifiedDate: 2024-11-13 + eslint: + lastModifiedDate: 2024-11-13 + psscriptanalyzer: + lastModifiedDate: 2024-11-13 + armory: + lastModifiedDate: 2024-11-13 + binary: + credscan: + lastModifiedDate: 2024-11-13 + binskim: + lastModifiedDate: 2024-11-13 + spotbugs: + lastModifiedDate: 2024-11-13 diff --git a/.config/guardian/.gdnbaselines b/.config/guardian/.gdnbaselines new file mode 100644 index 0000000000000..a7ee2a4b69dda --- /dev/null +++ b/.config/guardian/.gdnbaselines @@ -0,0 +1,43 @@ +{ + "properties": { + "helpUri": "https://eng.ms/docs/microsoft-security/security/azure-security/cloudai-security-fundamentals-engineering/security-integration/guardian-wiki/microsoft-guardian/general/baselines" + }, + "version": "1.0.0", + "baselines": { + "default": { + "name": "default", + "createdDate": "2024-11-13 00:40:35Z", + "lastUpdatedDate": "2024-11-13 00:40:35Z" + } + }, + "results": { + "48f03e2797fc40ecea50f878a0268947c7e13db1b2fa51aa3981246844fc4c68": { + "signature": "48f03e2797fc40ecea50f878a0268947c7e13db1b2fa51aa3981246844fc4c68", + "alternativeSignatures": [], + "target": "ScanTelemetry_20241113003616898.json", + "line": 1, + "memberOf": [ + "default" + ], + "tool": "credscan", + "ruleId": "CSCAN-AZURE0130", + "createdDate": "2024-11-13 00:40:35Z", + "expirationDate": "2025-05-02 01:29:47Z", + "justification": "This error is baselined with an expiration date of 180 days from 2024-11-13 01:29:47Z" + }, + "9cb6eddb3f3e886ad06cae65f5886412ff0c5fb0b96d4e943e4efa237be617b1": { + "signature": "9cb6eddb3f3e886ad06cae65f5886412ff0c5fb0b96d4e943e4efa237be617b1", + "alternativeSignatures": [], + "target": "ScanTelemetry_20241113111547065.json", + "line": 1, + "memberOf": [ + "default" + ], + "tool": "credscan", + "ruleId": "CSCAN-AZURE0130", + "createdDate": "2024-11-13 11:20:17Z", + "expirationDate": "2025-05-02 11:55:15Z", + "justification": "This error is baselined with an expiration date of 180 days from 2024-11-13 11:55:15Z" + } + } +} \ No newline at 
end of file diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml new file mode 100644 index 0000000000000..6a76f7bcdbcb0 --- /dev/null +++ b/.github/codeql/codeql-config.yml @@ -0,0 +1,7 @@ +name: "CodeQL config" +queries: + - uses: security-extended + - uses: security-and-quality +paths-ignore: + - tests + - build \ No newline at end of file diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 7144363717749..0cbaf24059390 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -8,7 +8,7 @@ on: jobs: validate: name: "validate" - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Check out a copy of the repository uses: actions/checkout@v4 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index e4d1b91bab736..d1dc717c2a9c9 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -15,10 +15,14 @@ on: schedule: - cron: '41 13 * * 0' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: analyze: name: Analyze - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] permissions: actions: read contents: read @@ -55,6 +59,11 @@ jobs: java-version: '11' distribution: 'microsoft' + - if: ${{ matrix.language == 'javascript' }} + uses: actions/setup-node@v4 + with: + node-version: 20 + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - if: ${{ matrix.language != 'cpp' }} diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 32aed81092774..cf3bc598d02bb 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -8,7 +8,7 @@ on: [push, pull_request] jobs: validation: name: "Validation" - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: actions/checkout@v4 - uses: gradle/actions/wrapper-validation@v4 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index a196226a4b836..00960c848b107 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -8,7 +8,7 @@ permissions: jobs: triage: - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: github/issue-labeler@v3.4 with: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 2edbe2d814533..8d966d358de01 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,6 +7,10 @@ on: - rel-* pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + jobs: optional-lint: name: Optional Lint @@ -32,23 +36,29 @@ jobs: lint-python-format: # Required workflow name: Python format - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: + contents: read + security-events: write steps: - uses: actions/checkout@v4 - name: Setup Python uses: actions/setup-python@v5 with: - # Version range or exact version of Python to use, using SemVer's version range syntax. Reads from .python-version if unset. 
+ # Use the version configured in target-version of [tool.black] section in pyproject.toml. python-version: "3.10" - name: Setup Rust uses: actions-rs/toolchain@v1 with: toolchain: stable components: rustfmt + - name: Update PATH + run: | + echo "$HOME/.local/bin" >> "$GITHUB_PATH" - name: Install dependencies run: | - python -m pip install -r requirements-dev.txt - python -m pip install lintrunner lintrunner-adapters + set -e -x + python -m pip install --user -r requirements-dev.txt lintrunner init - name: Run lintrunner on all files run: | @@ -77,8 +87,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@master + - name: Update PATH + run: | + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + - name: Install ninja - run: python -m pip install --upgrade ninja + run: python -m pip install --user --upgrade ninja - name: Generate compile_commands.json run: | python tools/ci_build/build.py \ @@ -110,9 +124,12 @@ jobs: lint-js: name: Lint JavaScript - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: 20 - uses: reviewdog/action-eslint@v1 with: reporter: github-pr-check diff --git a/.github/workflows/linux_training.yml b/.github/workflows/linux_training.yml new file mode 100644 index 0000000000000..d382cdf476283 --- /dev/null +++ b/.github/workflows/linux_training.yml @@ -0,0 +1,55 @@ +name: orttraining-linux-ci-pipeline +on: + push: + branches: + - main + - rel-* + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + orttraining-linux-ci-pipeline: + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: + actions: read + contents: read + security-events: write + steps: + - uses: actions/checkout@v4 + - run: | + python3 -m pip install --user -r tools/ci_build/github/linux/python/requirements.txt + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + config-file: ./.github/codeql/codeql-config.yml + languages: 'cpp' + - run: | + set -e -x + rm -rf build + python3 tools/ci_build/build.py --build_dir build --config Release --enable_training --skip_submodule_sync --parallel --update --build + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp" + output: sarif-results + upload: failure-only + + - name: filter-sarif + uses: advanced-security/filter-sarif@v1 + with: + patterns: | + +**/*.cc + +**/*.h + -tests/**/*.* + -build/**/*.* + input: sarif-results/cpp.sarif + output: sarif-results/cpp.sarif + + - name: Upload SARIF + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: sarif-results/cpp.sarif \ No newline at end of file diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index d1a4366da45e2..b36b0aa555940 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -20,7 +20,7 @@ env: jobs: ARM64-Xcode16: - runs-on: macos-14 + runs-on: macos-15 env: xcode_version: 16 @@ -60,12 +60,16 @@ jobs: --use_xnnpack \ --use_binskim_compliant_compile_flags - ARM64-Xcode16-targeting-iphonesimulator-x86_64: - runs-on: macos-14 + ARM64-Xcode16-targeting-iphonesimulator: + runs-on: macos-15 env: xcode_version: 16 + strategy: + matrix: + target_arch: [x86_64, arm64] + timeout-minutes: 60 steps: @@ -87,16 +91,14 @@ jobs: - uses: actions/checkout@v4 - # Note: Setting onnxruntime_BUILD_UNIT_TESTS=OFF as a workaround for - # 
https://github.com/microsoft/onnxruntime/issues/22245. - - name: Build + - name: Build for iphonesimulator ${{ matrix.target_arch }} shell: bash run: | python ./tools/ci_build/build.py \ --build_dir ./build \ --update \ --build --parallel \ - --skip_tests \ + --test \ --build_apple_framework \ --use_xcode \ --use_coreml \ @@ -105,8 +107,7 @@ jobs: --ios \ --apple_deploy_target=13.0 \ --apple_sysroot=iphonesimulator \ - --osx_arch=x86_64 \ - --cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF + --osx_arch=${{ matrix.target_arch }} Vcpkg: runs-on: macos-13 diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml new file mode 100644 index 0000000000000..af890d88995be --- /dev/null +++ b/.github/workflows/pr_checks.yml @@ -0,0 +1,52 @@ +# Copyright (c) ONNX Project Contributors +# +# SPDX-License-Identifier: Apache-2.0 + +name: PR Checks + +on: + pull_request: + branches: + - main + +permissions: # set top-level default permissions as security best practice + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + auto-apply-fixes: + name: Suggest fixes + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: + contents: read + pull-requests: write + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Setup Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + components: rustfmt + + - name: Update PATH + run: | + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + + - name: Install dependencies and run lintrunner on all files + run: | + python -m pip install --user -r requirements-dev.txt + python -m pip install --user lintrunner lintrunner-adapters + lintrunner init + set +e + lintrunner f --all-files -v + exit 0 + - uses: parkerbxyz/suggest-changes@v2 + with: + comment: 'You can commit the suggested changes from lintrunner.' 
diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 6c4dc43847d0b..6d3e593d8694e 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -9,7 +9,7 @@ on: - include/onnxruntime/core/session/** - orttraining/orttraining/training_api/include/** schedule: - - cron: '0 0 1 * *' + - cron: '0 0 1,15 * *' workflow_dispatch: concurrency: @@ -22,7 +22,7 @@ permissions: jobs: build: name: Generate C/C++ API docs - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: actions/checkout@v4 - name: Install doxygen and dependencies diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 862a7a70e33a2..7cca0969a168b 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -8,7 +8,7 @@ on: paths: - csharp/** schedule: - - cron: '0 0 1 * *' + - cron: '0 0 1,15 * *' workflow_dispatch: concurrency: @@ -20,18 +20,17 @@ permissions: jobs: build: - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] env: DOCFXVERSION: 2.62.2 steps: - uses: actions/checkout@v4 - - name: Setup .NET - uses: actions/setup-dotnet@v4 - with: - dotnet-version: 8.0.x - name: Install DocFX run: | dotnet tool update -g docfx + - name: Update PATH + run: | + Add-Content -Value "$env:USERPROFILE\.dotnet\tools" -Encoding utf8 -Path $env:GITHUB_PATH # NOTE: We need to restore Microsoft.ML.OnnxRuntime.csproj manually to set IncludeMobileTargets=false # docfx doesn't seem to be able to do that properly resulting in build errors - name: Restore dependencies @@ -50,10 +49,12 @@ jobs: - name: Log source commit run: git rev-parse --short HEAD > csharp/ApiDocs/csharp/source-version.txt - name: Move C# docs into site + shell: pwsh run: | - mkdir -p _site/docs/api - rm -rf _site/docs/api/csharp - mv csharp/ApiDocs/csharp _site/docs/api/csharp + New-Item -Path _site/docs/api -Force -ItemType "Directory" | Out-Null + $OutputDirectory="_site/docs/api/csharp" + if (Test-Path $OutputDirectory) { Remove-Item -Recurse -Force $OutputDirectory } + Move-Item -Path csharp\ApiDocs\csharp -Destination $OutputDirectory - name: Upload docs artifact uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/publish-gh-pages.yml b/.github/workflows/publish-gh-pages.yml index 1818261b4b766..11745ce24f9e5 100644 --- a/.github/workflows/publish-gh-pages.yml +++ b/.github/workflows/publish-gh-pages.yml @@ -8,7 +8,7 @@ on: jobs: placeholder: - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Placeholder step to have workflow included in the GitHub web UI run: | diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 9e42dca708a17..d04669a13aab7 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -8,7 +8,7 @@ on: paths: - java/** schedule: - - cron: '0 0 1 * *' + - cron: '0 0 1,15 * *' workflow_dispatch: concurrency: @@ -21,7 +21,7 @@ permissions: jobs: build: name: Generate Java docs - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: actions/checkout@v4 - name: Set up JDK 11 diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index cec4a52d39c93..a6749b42adc35 100644 --- 
a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -8,7 +8,7 @@ on: paths: - js/common/** schedule: - - cron: '0 0 1 * *' + - cron: '0 0 1,15 * *' workflow_dispatch: concurrency: @@ -21,7 +21,7 @@ permissions: jobs: build: name: Generate JS API docs - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: actions/checkout@v4 - name: Setup Node.js diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index a8b81c8d5cf84..deef64f73f15a 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -8,7 +8,7 @@ on: paths: - objectivec/** schedule: - - cron: '0 0 1 * *' + - cron: '0 0 1,15 * *' workflow_dispatch: concurrency: diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index 8b2f72d80bacf..adc2346d1bf1b 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -9,7 +9,7 @@ on: - onnxruntime/python/** - docs/python/** schedule: - - cron: '0 0 1 * *' + - cron: '0 0 1,15 * *' workflow_dispatch: concurrency: @@ -22,7 +22,7 @@ permissions: jobs: build: name: Generate Python API docs - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: actions/checkout@v4 - name: Install tools @@ -32,10 +32,10 @@ jobs: sudo apt-get install graphviz - name: Install dependencies run: | - python3 -m pip install --upgrade pip + python3 -m pip install --user --upgrade pip cd docs/python - python3 -m pip install -r requirements.txt - python3 -m pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_cpu.html + python3 -m pip install --user -r requirements.txt + python3 -m pip install --user --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_cpu.html python3 -m pip list - name: Generate Python docs with Sphinx run: | diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml index 0867d4c343e91..51166293f06ac 100644 --- a/.github/workflows/sca.yml +++ b/.github/workflows/sca.yml @@ -30,7 +30,7 @@ jobs: - uses: actions/setup-node@v4 with: - node-version: 18 + node-version: 20 - name: Download cuda run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk @@ -57,6 +57,45 @@ jobs: sarif_file: ${{ github.workspace }}\output\MergeResult.sarif category: VS_SCA + # With WebGPU, Without python + Onnxruntime-SCA-win32-WebGPU-x64: + permissions: + security-events: write + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] + steps: + - uses: actions/checkout@v4 + with: + submodules: false + - uses: actions/setup-python@v5 + with: + python-version: '3.11.x' + architecture: 'x64' + + - uses: actions/setup-node@v4 + with: + node-version: 20 + + - name: Delete build folder + run: | + if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } + + + - name: Build code + env: + CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' + run: python tools\ci_build\build.py --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines 
onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_webgpu + + - name: Generate sarif + working-directory: D:\b + run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output + + - name: Upload SARIF to GitHub + uses: github/codeql-action/upload-sarif@v3 + continue-on-error: true + with: + sarif_file: ${{ github.workspace }}\output\MergeResult.sarif + category: VS_SCA_WIN32_WEBGPU_X64 + # No python Onnxruntime-SCA-win32-WINML-x64: permissions: @@ -73,7 +112,7 @@ jobs: - uses: actions/setup-node@v4 with: - node-version: 18 + node-version: 20 - name: Delete build folder run: | @@ -113,7 +152,7 @@ jobs: - uses: actions/setup-node@v4 with: - node-version: 18 + node-version: 20 - name: Delete build folder run: | diff --git a/.github/workflows/skip-doc-change.yml.j2 b/.github/workflows/skip-doc-change.yml.j2 index 58f048122a87e..04f77e5d28713 100644 --- a/.github/workflows/skip-doc-change.yml.j2 +++ b/.github/workflows/skip-doc-change.yml.j2 @@ -14,7 +14,7 @@ jobs: {%- for name in job_names %} job{{ loop.index }}: name: {{ name }} - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - run: 'echo "No build required, only documentation changed"' {% endfor %} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 181f3fb17d332..14cf0825873a0 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -8,7 +8,7 @@ on: jobs: close-stale-issues: - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] permissions: issues: write pull-requests: write diff --git a/.github/workflows/title-only-labeler.yml b/.github/workflows/title-only-labeler.yml index e0af2dd06b1b7..7ee9f3917a901 100644 --- a/.github/workflows/title-only-labeler.yml +++ b/.github/workflows/title-only-labeler.yml @@ -8,7 +8,7 @@ permissions: jobs: triage: - runs-on: ubuntu-latest + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: github/issue-labeler@v3.4 with: diff --git a/.lintrunner.toml b/.lintrunner.toml index be46ba0baabdb..5ef9ad9337f57 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -2,31 +2,23 @@ # You can install the dependencies and initialize with # # ```sh -# pip install lintrunner lintrunner-adapters +# pip install -r requirements-lintrunner.txt # lintrunner init # ``` # # This will install lintrunner on your system and download all the necessary # dependencies to run linters locally. -# If you want to see what lintrunner init will install, run -# `lintrunner init --dry-run`. # -# To lint local changes: +# To format local changes: # # ```bash -# lintrunner +# lintrunner -a # ``` # -# To lint all files: +# To format all files: # # ```bash -# lintrunner --all-files -# ``` -# -# To format files: -# -# ```bash -# lintrunner f --all-files +# lintrunner -a --all-files # ``` # # To read more about lintrunner, see [wiki](https://github.com/pytorch/pytorch/wiki/lintrunner). 
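The rewritten .lintrunner.toml header above now points contributors at requirements-lintrunner.txt and at lintrunner -a for formatting. A minimal local sketch of that flow, assuming only that requirements-lintrunner.txt sits at the repository root as the header implies:

    pip install -r requirements-lintrunner.txt   # installs lintrunner and its adapters
    lintrunner init                              # downloads the linter dependencies declared in .lintrunner.toml
    lintrunner -a                                # lint local changes and auto-apply fixes
    lintrunner -a --all-files                    # same, but across all files in the repo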
diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 294bd926a34cb..b9932eb563b83 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index 3528545dfb06e..37fe2d378b7fd 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/CODEOWNERS b/CODEOWNERS index f7dfa419500d0..a55067ed798d8 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -9,10 +9,6 @@ /onnxruntime/core/graph/contrib_ops/quantization_defs.* @microsoft/onnxruntime-mlas /onnxruntime/core/mlas/** @microsoft/onnxruntime-mlas -# build pipelines and workflows -/tools/ci_build/github/azure-pipelines @microsoft/onnxruntime-es -/.github/workflows @microsoft/onnxruntime-es - # Dependencies requirements-dev.txt @microsoft/onnxruntime-admin requirements-doc.txt @microsoft/onnxruntime-admin diff --git a/CPPLINT.cfg b/CPPLINT.cfg new file mode 100644 index 0000000000000..12c1c7be0d773 --- /dev/null +++ b/CPPLINT.cfg @@ -0,0 +1 @@ +filter=-whitespace diff --git a/README.md b/README.md index cde039cec52a8..f1817282b61a0 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ * **YouTube video tutorials**: [youtube.com/@ONNXRuntime](https://www.youtube.com/@ONNXRuntime) -* [**Upcoming Release Roadmap**](https://github.com/microsoft/onnxruntime/wiki/Upcoming-Release-Roadmap) +* [**Upcoming Release Roadmap**](https://onnxruntime.ai/roadmap) * **Companion sample repositories**: - ONNX Runtime Inferencing: [microsoft/onnxruntime-inference-examples](https://github.com/microsoft/onnxruntime-inference-examples) @@ -24,8 +24,8 @@ |System|Inference|Training| |---|---|---| -|Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CI%20Pipeline?label=Windows+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)|| -|Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining/orttraining-ortmodule-distributed?label=Training+Distributed)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=148)| +|Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CUDA%20CI%20Pipeline?label=Windows+GPU+CUDA)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=218)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20WebGPU%20CI%20Pipeline?label=Windows+GPU+WebGPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=228)|| +|Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)| |Mac|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20CI%20Pipeline?label=MacOS+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)|| |Android|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)|| |iOS|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)|| @@ -40,6 +40,12 @@ This project is tested with [BrowserStack](https://www.browserstack.com/home). |---|---|---| |Linux|[![Build Status](https://github.com/Ascend/onnxruntime/actions/workflows/build-and-test.yaml/badge.svg)](https://github.com/Ascend/onnxruntime/actions/workflows/build-and-test.yaml)|| +## Releases + +The current release and past releases can be found here: https://github.com/microsoft/onnxruntime/releases. + +For details on the upcoming release, including release dates, announcements, features, and guidance on submitting feature requests, please visit the release roadmap: https://onnxruntime.ai/roadmap. + ## Data/Telemetry Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the [privacy statement](docs/Privacy.md) for more details. diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 6a11f414361bd..26084ab42ec1c 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -2108,261 +2108,6 @@ SOFTWARE. _____ -TVM Open Deep Learning Compiler Stack - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -CONTRIBUTORS - -TVM Contributors -================ -TVM adopts the Apache style model and governs by merit. We believe that it is important to create an inclusive community where everyone can use, -contribute to, and influence the direction of the project. We actively invite contributors who have earned the merit to be part of the development community. - -See the [community structure document](http://docs.tvm.ai/contribute/community.html) for the explanation of community structure and contribution guidelines. 
- -## Committers -- [Tianqi Chen](https://github.com/tqchen) (PMC) -- [Thierry Moreau](http://homes.cs.washington.edu/~moreau/) -- [Ziheng Jiang](https://github.com/ZihengJiang) -- [Haichen Shen](http://homes.cs.washington.edu/~haichen/) -- [Yizhi Liu](https://github.com/yzhliu) - -## Code Owners -- [Aditya Atluri](https://github.com/adityaatluri) ROCM -- [Leyuan Wang](https://github.com/Laurawly) TOPI -- [Yuwei Hu](https://github.com/Huyuwei) TOPI -- [Zhixun Tan](https://github.com/phisiart) OpenGL/WebGL backend -- [Nick Hynes](https://github.com/nhynes) SGX and secured computing -- [Lianmin Zheng](https://github.com/merrymercy) AutoTVM - -## Reviewers -- [Zhi Chen](https://github.com/zhiics) -- [Xiaoqiang Dan](https://github.com/xqdan) -- [Liangfu Chen](https://github.com/liangfu) -- [Masahiro Masuda](https://github.com/masahi) -- [Kazutaka Morita](https://github.com/kazum) -- [Tatsuya Nishiyama](https://github.com/nishi-t) -- [Pariksheet Pinjari](https://github.com/PariksheetPinjari909) -- [Jared Roesch](https://github.com/jroesch) -- [Siva](https://github.com/srkreddy1238) -- [Siju Samuel](https://github.com/siju-samuel) -- [Alex Weaver](https://github.com/alex-weaver) -- [Yao Wang](https://github.com/kevinthesun) -- [Jian Weng](https://github.com/were) -- [Eddie Yan](https://github.com/eqy) -- [Joshua Z. Zhang](https://github.com/zhreshold) - -## List of Contributors -- [Full List of Contributors](https://github.com/dmlc/tvm/graphs/contributors) - - To contributors: please add your name to the list. -- [Qiao Zhang](https://github.com/zhangqiaorjc) -- [Haolong Zhang](https://github.com/haolongzhangm) -- [Cody Hao Yu](https://github.com/comaniac) -- [Chris Nuernberger](https://github.com/cnuernber) - -_____ - FreeBSD: getopt.c file Copyright (c) 1987, 1993, 1994 @@ -2492,212 +2237,6 @@ DAMAGE. _____ -google/nsync - -Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -_____ - google/re2 Copyright (c) 2009 The RE2 Authors. All rights reserved. 
diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 3989355915568..3500250a4b05b 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.20.0 +1.21.0 diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index 1432193ac9080..46349f43923e2 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -1,578 +1,508 @@ { - "$schema": "https://json.schemastore.org/component-detection-manifest.json", - "Registrations": [ - { - "component": { - "type": "git", - "git": { - "commitHash": "215105818dfde3174fe799600bb0f3cae233d0bf", - "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" - } - } - }, - { - "component": { - "Type": "maven", - "maven": { - "GroupId": "org.junit.platform", - "ArtifactId": "junit-platform-console-standalone", - "Version": "1.6.2" - }, - "DevelopmentDependency": true - } - }, - { - "component": { - "Type": "maven", - "maven": { - "GroupId": "com.google.protobuf", - "ArtifactId": "protobuf-java", - "Version": "3.21.7" - }, - "DevelopmentDependency": true - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "2379917985919ed3918dc12cad47f469f245be7a", - "repositoryUrl": "https://github.com/apache/tvm.git" - }, - "comments": "needed for TVM EP" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "cabe04d6d6b05356fa8f9741704924788f0dd762", - "repositoryUrl": "https://github.com/agauniyal/rang.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "a3bcc6981d5dad3afb212689e2c7853d1b1ee45d", - "repositoryUrl": "https://github.com/NVIDIA/cutlass.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "08f7c7e69f8ea61a0c4151359bc8023be8e9217b", - "repositoryUrl": "https://github.com/tlc-pack/libbacktrace.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "36a91576edf633479c78649e050f18dd2ddc8103", - "repositoryUrl": "https://github.com/apache/incubator-tvm-vta.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "111c9be5188f7350c2eac9ddaedd8cca3d7bf394", - "repositoryUrl": "https://github.com/kazuho/picojson.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "b5e4186d7ab63458e79084842dced166be2ca5b5", - "repositoryUrl": "https://github.com/lammertb/libcrc.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "e4a4c02764d37c9c3db0d64c4996651a3ef9513c", - "repositoryUrl": "https://github.com/dmlc/HalideIR.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "bee4d1dd8dc1ee4a1fd8fa6a96476c2f8b7492a3", - "repositoryUrl": "https://github.com/dmlc/dlpack.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "4d49691f1a9d944c3b0aa5e63f1db3cad1f941f8", - "repositoryUrl": "https://github.com/dmlc/dmlc-core.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "7de7e5d02bf687f971e7668963649728356e0c20", - "repositoryUrl": "https://github.com/intel/mkl-dnn.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "d860915b0198ddb96f93e9e97a789af156544dc6", - "repositoryUrl": "https://github.com/tensorflow/tensorflow.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": 
"eddf9023206dc40974c26f589ee2ad63a4227a1e", - "repositoryUrl": "https://github.com/glennrp/libpng.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "217f52fb121ef92491e5d5f71394b07ce4ead1d0", - "repositoryUrl": "https://github.com/KjellKod/g3log.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "50893291621658f355bc5b4d450a8d06a563053d", - "repositoryUrl": "https://github.com/madler/zlib.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "d264a2603493fecda607c1d1cda87fedba77d36b", - "repositoryUrl": "https://github.com/Microsoft/CNTK.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "971e2e89d08deeae0139d3011d15646fdac13c92", - "repositoryUrl": "https://github.com/numpy/numpy.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "90537289a04ef5d572496240e2ac3a881be518d2", - "repositoryUrl": "https://github.com/pytorch/pytorch.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "b31f58de6fa8bbda5353b3c77d9be4914399724d", - "repositoryUrl": "https://github.com/pytorch/pytorch.git" - }, - "comments": "pytorch 1.6 used by onnxruntime training image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "7389dbac82d362f296dc2746f10e43ffa1615660", - "repositoryUrl": "https://github.com/scikit-learn/scikit-learn.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "eeebdab16155d34ff8f5f42137da7df4d1c7eab0", - "repositoryUrl": "https://github.com/BVLC/caffe.git" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "LLVM", - "Version": "9.0.0", - "DownloadUrl": "https://releases.llvm.org/9.0.0/llvm-9.0.0.src.tar.xz" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "FreeBSD GetOpt", - "Version": "12.0.0", - "DownloadUrl": "https://svnweb.freebsd.org/base/release/12.0.0/lib/libc/stdlib/getopt.c?revision=341707&view=co" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "Boost", - "Version": "1.69.0", - "DownloadUrl": "https://boostorg.jfrog.io/artifactory/main/release/1.69.0/source/boost_1_69_0.tar.bz2" - } - } - }, - { - "component": { - "git": { - "commitHash": "02a2a458ac15912d7d87cc1171e811b0c5219ece", - "repositoryUrl": "https://github.com/grpc/grpc" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "b29b21a81b32ec273f118f589f46d56ad3332420", - "repositoryUrl": "https://github.com/google/boringssl.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "3be1924221e1326df520f8498d704a5c4c8d0cce", - "repositoryUrl": "https://github.com/c-ares/c-ares.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "6599cac0965be8e5a835ab7a5684bbef033d5ad0", - "repositoryUrl": "https://github.com/llvm-mirror/libcxx.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "9245d481eb3e890f708ff2d7dadf2a10c04748ba", - "repositoryUrl": "https://github.com/llvm-mirror/libcxxabi.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "9ce4a77f61c134bbed28bfd5be5cd7dc0e80f5e3", - "repositoryUrl": "https://github.com/google/upb.git" - }, - "type": "git" - } - }, - { - "component": { - "type": "other", - "Other": { - "Name": "Go", - "Version": "1.12.6", - "DownloadUrl": "https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz" - } - } - }, - { - "component": { - "Type": 
"other", - "Other": { - "Name": "OpenMPI", - "Version": "4.0.0", - "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "OpenMPI", - "Version": "4.0.4", - "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.4.tar.gz" - }, - "comments": "openmpi 4.0.4 used by onnxruntime training image" - } - }, - { - "component": { - "Type": "git", - "git": { - "commitHash": "7db3f9c741d3dfd8dda14ffb537ed251280d2025", - "repositoryUrl": "https://github.com/mpi4py/mpi4py" - }, - "comments": "mpi4py 3.0.3 used by onnxruntime training image" - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "NCCL", - "Version": "2.4.8", - "DownloadUrl": "https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "67afac65ce64fd4dce1494f43e565e8fe34bdffb", - "repositoryUrl": "https://android.googlesource.com/platform/frameworks/ml" - }, - "comments": "used by onnxruntime" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c30b7da2301202da5f9f0529966944f110e5d6e7", - "repositoryUrl": "https://github.com/openucx/ucx" - }, - "comments": "middleware between IB verbs and OpenMPI used by onnxruntime training image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "63d1e08e64e7e09408eb63cd8dd7c65ad766f277", - "repositoryUrl": "https://github.com/nodejs/node" - }, - "comments": "For Nodejs binding" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "aead4d751c2101e23336aa73f2380df83e7a13f3", - "repositoryUrl": "https://github.com/pypa/manylinux" - }, - "comments": "For building our CI build docker image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c974557598645360fbabac71352b083117e3cc17", - "repositoryUrl": "https://gitlab.kitware.com/cmake/cmake" - }, - "comments": "CMake 3.24.3. 
For building our CI build docker image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "1e5d33e9b9b8631b36f061103a30208b206fd03a", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.9.1" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6503f05dd59e26a9986bdea097b3da9b3546f45b", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.8.7" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "13c94747c74437e594b7fc242ff7da668e81887c", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.7.9" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c0a9afe2ac1820409e6173bd1893ebee2cf50270", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.6.12" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "426b022776672fdf3d71ddd98d89af341c88080f", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.5.10" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "transformers", - "Version": "4.38.0" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "msgpack", - "Version": "1.0.0" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "tensorboardX", - "Version": "1.8" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "tensorboard", - "Version": "2.3.0" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "92cf3702fcfaadc84eb7bef59825a23e0cd84f56", - "repositoryUrl": "https://github.com/aappleby/smhasher" - }, - "comments": "MurmurHash3" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "b89da3c5a0aa18fb2c6163ad9984f81ab65b22e3", - "repositoryUrl": "https://github.com/mestevens/gtest-ios-framework" - }, - "comments": "gtest-ios-framework" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", - "repositoryUrl": "https://github.com/dmlc/dlpack.git" - }, - "comments": "dlpack" - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "SQLite3", - "Version": "3.22.0", - "DownloadUrl": "http://security.ubuntu.com/ubuntu/pool/main/s/sqlite3/libsqlite3-dev_3.22.0-1ubuntu0.4_amd64.deb" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "9d0ef119d9fcb9139f831adc224857b791c81140", - "repositoryUrl": "https://github.com/dlfcn-win32/dlfcn-win32.git" - }, - "comments": "dlfcn-win32" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6812205f18ca4ef54372e87e1a13ce4a859434df", - "repositoryUrl": "https://github.com/python-pillow/Pillow.git" - }, - "comments": "python-pillow. Implementation logic for anti-aliasing copied by Resize CPU kernel." 
- } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "e7248b26a1ed53fa030c5c459f7ea095dfd276ac", - "repositoryUrl": "https://gitlab.com/libeigen/eigen.git" - } - } - } - ], - "Version": 1 + "$schema": "https://json.schemastore.org/component-detection-manifest.json", + "Registrations": [ + { + "component": { + "type": "git", + "git": { + "commitHash": "215105818dfde3174fe799600bb0f3cae233d0bf", + "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" + } + } + }, + { + "component": { + "Type": "maven", + "maven": { + "GroupId": "org.junit.platform", + "ArtifactId": "junit-platform-console-standalone", + "Version": "1.6.2" + }, + "DevelopmentDependency": true + } + }, + { + "component": { + "Type": "maven", + "maven": { + "GroupId": "com.google.protobuf", + "ArtifactId": "protobuf-java", + "Version": "3.21.7" + }, + "DevelopmentDependency": true + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "e4a4c02764d37c9c3db0d64c4996651a3ef9513c", + "repositoryUrl": "https://github.com/dmlc/HalideIR.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "bee4d1dd8dc1ee4a1fd8fa6a96476c2f8b7492a3", + "repositoryUrl": "https://github.com/dmlc/dlpack.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4d49691f1a9d944c3b0aa5e63f1db3cad1f941f8", + "repositoryUrl": "https://github.com/dmlc/dmlc-core.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7de7e5d02bf687f971e7668963649728356e0c20", + "repositoryUrl": "https://github.com/intel/mkl-dnn.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d860915b0198ddb96f93e9e97a789af156544dc6", + "repositoryUrl": "https://github.com/tensorflow/tensorflow.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "eddf9023206dc40974c26f589ee2ad63a4227a1e", + "repositoryUrl": "https://github.com/glennrp/libpng.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "217f52fb121ef92491e5d5f71394b07ce4ead1d0", + "repositoryUrl": "https://github.com/KjellKod/g3log.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "50893291621658f355bc5b4d450a8d06a563053d", + "repositoryUrl": "https://github.com/madler/zlib.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d264a2603493fecda607c1d1cda87fedba77d36b", + "repositoryUrl": "https://github.com/Microsoft/CNTK.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "971e2e89d08deeae0139d3011d15646fdac13c92", + "repositoryUrl": "https://github.com/numpy/numpy.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "90537289a04ef5d572496240e2ac3a881be518d2", + "repositoryUrl": "https://github.com/pytorch/pytorch.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b31f58de6fa8bbda5353b3c77d9be4914399724d", + "repositoryUrl": "https://github.com/pytorch/pytorch.git" + }, + "comments": "pytorch 1.6 used by onnxruntime training image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7389dbac82d362f296dc2746f10e43ffa1615660", + "repositoryUrl": "https://github.com/scikit-learn/scikit-learn.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "eeebdab16155d34ff8f5f42137da7df4d1c7eab0", + "repositoryUrl": "https://github.com/BVLC/caffe.git" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + 
"Name": "LLVM", + "Version": "9.0.0", + "DownloadUrl": "https://releases.llvm.org/9.0.0/llvm-9.0.0.src.tar.xz" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "FreeBSD GetOpt", + "Version": "12.0.0", + "DownloadUrl": "https://svnweb.freebsd.org/base/release/12.0.0/lib/libc/stdlib/getopt.c?revision=341707&view=co" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "Boost", + "Version": "1.69.0", + "DownloadUrl": "https://boostorg.jfrog.io/artifactory/main/release/1.69.0/source/boost_1_69_0.tar.bz2" + } + } + }, + { + "component": { + "git": { + "commitHash": "02a2a458ac15912d7d87cc1171e811b0c5219ece", + "repositoryUrl": "https://github.com/grpc/grpc" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "b29b21a81b32ec273f118f589f46d56ad3332420", + "repositoryUrl": "https://github.com/google/boringssl.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "3be1924221e1326df520f8498d704a5c4c8d0cce", + "repositoryUrl": "https://github.com/c-ares/c-ares.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "6599cac0965be8e5a835ab7a5684bbef033d5ad0", + "repositoryUrl": "https://github.com/llvm-mirror/libcxx.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "9245d481eb3e890f708ff2d7dadf2a10c04748ba", + "repositoryUrl": "https://github.com/llvm-mirror/libcxxabi.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "9ce4a77f61c134bbed28bfd5be5cd7dc0e80f5e3", + "repositoryUrl": "https://github.com/google/upb.git" + }, + "type": "git" + } + }, + { + "component": { + "type": "other", + "Other": { + "Name": "Go", + "Version": "1.12.6", + "DownloadUrl": "https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "OpenMPI", + "Version": "4.0.0", + "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "OpenMPI", + "Version": "4.0.4", + "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.4.tar.gz" + }, + "comments": "openmpi 4.0.4 used by onnxruntime training image" + } + }, + { + "component": { + "Type": "git", + "git": { + "commitHash": "7db3f9c741d3dfd8dda14ffb537ed251280d2025", + "repositoryUrl": "https://github.com/mpi4py/mpi4py" + }, + "comments": "mpi4py 3.0.3 used by onnxruntime training image" + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "NCCL", + "Version": "2.4.8", + "DownloadUrl": "https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "67afac65ce64fd4dce1494f43e565e8fe34bdffb", + "repositoryUrl": "https://android.googlesource.com/platform/frameworks/ml" + }, + "comments": "used by onnxruntime" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c30b7da2301202da5f9f0529966944f110e5d6e7", + "repositoryUrl": "https://github.com/openucx/ucx" + }, + "comments": "middleware between IB verbs and OpenMPI used by onnxruntime training image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "63d1e08e64e7e09408eb63cd8dd7c65ad766f277", + "repositoryUrl": "https://github.com/nodejs/node" + }, + "comments": "For Nodejs binding" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": 
"aead4d751c2101e23336aa73f2380df83e7a13f3", + "repositoryUrl": "https://github.com/pypa/manylinux" + }, + "comments": "For building our CI build docker image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c974557598645360fbabac71352b083117e3cc17", + "repositoryUrl": "https://gitlab.kitware.com/cmake/cmake" + }, + "comments": "CMake 3.24.3. For building our CI build docker image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "1e5d33e9b9b8631b36f061103a30208b206fd03a", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.9.1" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "6503f05dd59e26a9986bdea097b3da9b3546f45b", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.8.7" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "13c94747c74437e594b7fc242ff7da668e81887c", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.7.9" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c0a9afe2ac1820409e6173bd1893ebee2cf50270", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.6.12" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "426b022776672fdf3d71ddd98d89af341c88080f", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.5.10" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "transformers", + "Version": "4.38.0" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "msgpack", + "Version": "1.0.0" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "tensorboardX", + "Version": "1.8" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "tensorboard", + "Version": "2.3.0" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "92cf3702fcfaadc84eb7bef59825a23e0cd84f56", + "repositoryUrl": "https://github.com/aappleby/smhasher" + }, + "comments": "MurmurHash3" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b89da3c5a0aa18fb2c6163ad9984f81ab65b22e3", + "repositoryUrl": "https://github.com/mestevens/gtest-ios-framework" + }, + "comments": "gtest-ios-framework" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", + "repositoryUrl": "https://github.com/dmlc/dlpack.git" + }, + "comments": "dlpack" + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "SQLite3", + "Version": "3.22.0", + "DownloadUrl": "http://security.ubuntu.com/ubuntu/pool/main/s/sqlite3/libsqlite3-dev_3.22.0-1ubuntu0.4_amd64.deb" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "9d0ef119d9fcb9139f831adc224857b791c81140", + "repositoryUrl": "https://github.com/dlfcn-win32/dlfcn-win32.git" + }, + "comments": "dlfcn-win32" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "6812205f18ca4ef54372e87e1a13ce4a859434df", + "repositoryUrl": "https://github.com/python-pillow/Pillow.git" + }, + "comments": "python-pillow. Implementation logic for anti-aliasing copied by Resize CPU kernel." 
+ } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "e7248b26a1ed53fa030c5c459f7ea095dfd276ac", + "repositoryUrl": "https://gitlab.com/libeigen/eigen.git" + } + } + } + ], + "Version": 1 } diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 55d5fae4dedcd..475f75b5bf19b 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "f46495ea96f68fc3f6c394f099b2992743f6ff7f", + "commitHash": "4447c7562e3bc702ade25105912dce503f0c4010", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" @@ -122,16 +122,6 @@ "comments": "google_benchmark" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "13de152c2a1cd73ff4df97bd2c406b6d15d34af3", - "repositoryUrl": "https://github.com/google/nsync.git" - }, - "comments": "google_nsync" - } - }, { "component": { "type": "git", @@ -206,7 +196,7 @@ "component": { "type": "git", "git": { - "commitHash": "9f98e2ebe7507fe0774d06a44bbf4b0e82cc9ce7", + "commitHash": "bc0d2e35909b8456abe32f3b30a49bb0c125e8b7", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" @@ -356,7 +346,7 @@ "component": { "type": "git", "git": { - "commitHash": "511eb80847afe6bded34ec491a38d5d78ba2d604", + "commitHash": "12a3b24c456cebd9fd11f23ac0164f78129b00c6", "repositoryUrl": "https://github.com/google/dawn.git" }, "comments": "dawn" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index c0772c33f6e5d..d2fe7e7457983 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -86,7 +86,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) # use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF) -option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) +cmake_dependent_option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) @@ -102,10 +102,10 @@ option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to provide eigen_SOURCE_PATH if turn this on." 
OFF) option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) -option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) +cmake_dependent_option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA; NOT WIN32" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -128,6 +128,10 @@ option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF) + +# When loading a delay loaded DLL, Windows searches the main EXE's folder first. +# In a Python process, it searches where python.exe lives, but it doesn't search the python package's installation folder. Therefore we cannot enable this flag when Python is enabled. +cmake_dependent_option(onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS "Delay load some of the dependent DLls that are part of the OS" ON "WIN32;NOT GDK_PLATFORM;NOT onnxruntime_ENABLE_PYTHON" OFF) option(onnxruntime_USE_DML "Build with DirectML support" OFF) option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF) option(onnxruntime_USE_WINML "Build with WinML support" OFF) @@ -140,13 +144,15 @@ option(onnxruntime_USE_TELEMETRY "Build with Telemetry" OFF) cmake_dependent_option(onnxruntime_USE_MIMALLOC "Override new/delete and arena allocator with mimalloc" OFF "WIN32;NOT onnxruntime_USE_CUDA;NOT onnxruntime_USE_OPENVINO" OFF) option(onnxruntime_USE_CANN "Build with CANN support" OFF) option(onnxruntime_USE_ROCM "Build with AMD GPU support" OFF) -option(onnxruntime_USE_TVM "Build with TVM support" OFF) -option(onnxruntime_TVM_CUDA_RUNTIME "Build TVM with CUDA support" OFF) -option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llvm-config.exe here if need" OFF) -option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only") option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF) +option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." 
OFF) +option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.") +option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF) +# The following 2 options are only for Windows +option(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN "Enable Vulkan backend for Dawn (on Windows)" OFF) +option(onnxruntime_ENABLE_DAWN_BACKEND_D3D12 "Enable D3D12 backend for Dawn (on Windows)" ON) # Options related to reducing the binary size produced by the build # XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON @@ -199,6 +205,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF) +option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF) # Enable bitcode for iOS option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF) @@ -250,6 +257,7 @@ cmake_dependent_option(MSVC_Z7_OVERRIDE "replacing /Zi and /ZI with /Z7 when usi option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF) option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF) +option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF) # ENABLE_TRAINING includes all training functionality # The following 2 entry points @@ -289,12 +297,50 @@ if (onnxruntime_USE_ROCM) message(FATAL_ERROR "ROCM does not support build with CUDA!") endif() + # replicate strategy used by pytorch to get ROCM_VERSION + # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # with modification + if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version") + message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n") + file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") + message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm_version.h ****\n") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") + message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h ****\n") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + endif() + + if (ROCM_VERSION_MATCH) + set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) + set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) + set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) + set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") + math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + 
${ROCM_VERSION_DEV_PATCH}") + + message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") + message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") + message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") + message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") + message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") + else() + message(FATAL_ERROR "Cannot determine ROCm version string") + endif() + + if (NOT CMAKE_HIP_COMPILER) set(CMAKE_HIP_COMPILER "${onnxruntime_ROCM_HOME}/llvm/bin/clang++") endif() if (NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES "gfx906;gfx908;gfx90a;gfx1030;gfx1100;gfx1101") + if (ROCM_VERSION_DEV VERSION_LESS "6.2") + message(FATAL_ERROR "CMAKE_HIP_ARCHITECTURES is not set when ROCm version < 6.2") + else() + set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx940;gfx941;gfx942;gfx1200;gfx1201") + endif() endif() file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*) @@ -326,35 +372,6 @@ if (onnxruntime_USE_ROCM) set(onnxruntime_HIPIFY_PERL ${HIPIFY_PERL_PATH}/hipify-perl) endif() - # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake - # with modification - if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version") - file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - endif() - - if (ROCM_VERSION_MATCH) - set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) - set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) - set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) - set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") - math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") - else() - message(FATAL_ERROR "Cannot determine ROCm version string") - endif() - message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n") - message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") - message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") - message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") - message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") - message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") message("\n***** HIP LANGUAGE CONFIG INFO ****\n") message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") message("CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}") @@ -751,21 +768,30 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_DISABLE_CONTRIB_OPS) set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) 
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0") set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") endif() + if (WIN32) + message( STATUS "Lean Attention unsupported in Windows") + set(onnxruntime_USE_LEAN_ATTENTION OFF) + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -779,6 +805,13 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) endif() + + if (onnxruntime_USE_LEAN_ATTENTION) + message( STATUS "Enable lean attention for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_LEAN_ATTENTION=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_LEAN_ATTENTION=1) + endif() + if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) message( STATUS "Enable memory efficient attention for CUDA EP") list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) @@ -874,11 +907,6 @@ if (onnxruntime_USE_SNPE) list(APPEND ONNXRUNTIME_PROVIDER_NAMES snpe) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_SNPE=1) endif() -if (onnxruntime_USE_TVM) - list(APPEND ORT_PROVIDER_FLAGS -DUSE_TVM=1) - list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_TVM=1) - list(APPEND ONNXRUNTIME_PROVIDER_NAMES tvm) -endif() if (onnxruntime_USE_WINML) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WINML=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WINML=1) @@ -931,6 +959,18 @@ if (onnxruntime_USE_WEBGPU) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + list(APPEND ORT_PROVIDER_FLAGS -DBUILD_DAWN_MONOLITHIC_LIBRARY=1) + endif() + if (onnxruntime_USE_EXTERNAL_DAWN) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_VULKAN=1) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1) + endif() endif() if (onnxruntime_USE_CANN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1) @@ -946,6 +986,10 @@ if (onnxruntime_USE_LOCK_FREE_QUEUE) add_compile_definitions(USE_LOCK_FREE_QUEUE) endif() +if (onnxruntime_FORCE_GENERIC_ALGORITHMS) + add_compile_definitions(FORCE_GENERIC_ALGORITHMS) +endif() + if (onnxruntime_ENABLE_LAZY_TENSOR) # To support LazyTensor, ORT needs to call Python function from C/C++. # so onnxruntime_ENABLE_PYTHON is required. 
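# Side note on the options hunk above: onnxruntime_USE_CUDA_NHWC_OPS,
# onnxruntime_USE_LEAN_ATTENTION and onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS
# all rely on the cmake_dependent_option() pattern. A minimal sketch of how that
# behaves, with hypothetical DEMO_* names and assuming a normal CMakeLists.txt:
include(CMakeDependentOption)
option(DEMO_USE_CUDA "Hypothetical parent toggle" OFF)
# DEMO_NHWC_OPS defaults to ON and is user-settable only while DEMO_USE_CUDA is
# ON; otherwise it is forced to the trailing value (OFF) and hidden from the cache.
cmake_dependent_option(DEMO_NHWC_OPS "Hypothetical dependent toggle" ON "DEMO_USE_CUDA" OFF)
message(STATUS "DEMO_NHWC_OPS=${DEMO_NHWC_OPS}")  # OFF here because DEMO_USE_CUDA is OFF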
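# Side note on the ROCm hunks above: the relocated detection block parses
# ${onnxruntime_ROCM_HOME}/.info/version (or a rocm_version.h header) into
# ROCM_VERSION_DEV, which the new CMAKE_HIP_ARCHITECTURES default checks against
# 6.2, plus an integer form major*10000 + minor*100 + patch. A standalone sketch
# with a made-up version string (runnable with `cmake -P`):
set(_demo_rocm_raw "6.2.0-116")  # stand-in for the real .info/version contents
string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)-.*$" _demo_match "${_demo_rocm_raw}")
if(_demo_match)
  set(_demo_rocm_ver "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}")
  math(EXPR _demo_rocm_int "(${CMAKE_MATCH_1}*10000) + (${CMAKE_MATCH_2}*100) + ${CMAKE_MATCH_3}")
  message(STATUS "version ${_demo_rocm_ver} packs to ${_demo_rocm_int}")  # 6.2.0 -> 60200
endif()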
@@ -1065,8 +1109,6 @@ function(onnxruntime_set_compile_flags target_name) if (CMAKE_CXX_COMPILER_ID STREQUAL "IBMClang") target_compile_options(${target_name} PRIVATE "-Wno-unused-function") endif() - target_compile_definitions(${target_name} PUBLIC -DNSYNC_ATOMIC_CPP11) - onnxruntime_add_include_to_target(${target_name} nsync::nsync_cpp) endif() foreach(ORT_FLAG ${ORT_PROVIDER_FLAGS}) target_compile_definitions(${target_name} PRIVATE ${ORT_FLAG}) @@ -1280,50 +1322,6 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -# TVM EP -if (onnxruntime_USE_TVM) - if (NOT TARGET tvm) - message(STATUS "Include TVM(*).") - include(tvm) - endif() - - # ipp-crypto - if (onnxruntime_TVM_USE_HASH) - message(STATUS "Include ipp-crypto(*).") - include(ipp-crypto) - endif() - - # TVM - if (onnxruntime_TVM_USE_LLVM) - set(USE_LLVM "${onnxruntime_TVM_USE_LLVM}" CACHE STRING "Path to LLVM for correct TVM build") - elseif(onnxruntime_USE_LLVM) - set(USE_LLVM ON CACHE BOOL "Only defined for TVM") - endif() - - if (onnxruntime_TVM_CUDA_RUNTIME) - set(USE_CUDA ON CACHE BOOL "Only defined for TVM" FORCE) - endif() - - # TODO(vvchernov): customized tvm logger is hidden due to the issue on TVM side (https://github.com/apache/tvm/issues/10139) - # add_compile_definitions(TVM_LOG_CUSTOMIZE=1) - # add_library(tvm_custom_logger STATIC ${ONNXRUNTIME_ROOT}/core/providers/tvm/custom_logging.cc) - - set(USE_OPENMP gnu CACHE STRING "Only defined for TVM") - add_subdirectory(${tvm_SOURCE_DIR} ${tvm_BINARY_DIR} EXCLUDE_FROM_ALL) - - set_target_properties(tvm PROPERTIES FOLDER ${tvm_SOURCE_DIR}) - # target_link_libraries(tvm PUBLIC tvm_custom_logger) - - set(TVM_INCLUDES ${tvm_SOURCE_DIR}/include - ${tvm_SOURCE_DIR}/3rdparty/dmlc-core/include - ${tvm_SOURCE_DIR}/3rdparty/dlpack/include - $) - - set(onnxruntime_tvm_libs onnxruntime_providers_tvm) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm) - list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm) -endif() - # onnxruntime-extensions if (onnxruntime_USE_EXTENSIONS) include(extensions) @@ -1334,7 +1332,7 @@ endif() #Adjust warning flags set_msvc_c_cpp_compiler_warning_level(4) -set(onnxruntime_DELAYLOAD_FLAGS "") +set(onnxruntime_DELAYLOAD_FLAGS ) include_directories( ${ONNXRUNTIME_INCLUDE_DIR} @@ -1352,6 +1350,7 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DUSE_OPENVINO=1) if(onnxruntime_NPU_NO_FALLBACK) + add_definitions(-DOPENVINO_CONFIG_NPU=1) add_definitions(-DOPENVINO_DISABLE_NPU_FALLBACK=1) endif() @@ -1654,7 +1653,6 @@ if (WIN32) list(APPEND onnxruntime_EXTERNAL_LIBRARIES advapi32) endif() else() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync::nsync_cpp) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${ICONV_LIB} ${CMAKE_DL_LIBS} Threads::Threads) endif() diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index c04d67ea4ce3f..dbbf685346532 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -60,6 +60,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") endif() + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + string(APPEND CMAKE_C_FLAGS " -DORT_WASM64") + string(APPEND CMAKE_CXX_FLAGS " -DORT_WASM64") + endif() + # Build WebAssembly with multi-threads support. 
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth") diff --git a/cmake/deps.txt b/cmake/deps.txt index c1bb7ffe98a06..ed41ad5b0ceb1 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/f46495ea96f68fc3f6c394f099b2992743f6ff7f.zip;0e2b6d1dc7f0a808d1e23f7dd985f7bc18d52cbc +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240722.0.zip;36ee53eb1466fb6e593fc5c286680de31f8a494a coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 @@ -27,7 +27,6 @@ flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 -google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349 #xnnpack 2024.09.04 googlexnnpack;https://github.com/google/XNNPACK/archive/309b75c9e56e0a674bf78d59872ce131f814dfb6.zip;39FA5259EAEACE0547284B63D5CEDC4F05553F5A @@ -37,8 +36,8 @@ microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.z mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.17.0.zip;13a60ac5217c104139ce0fd024f48628e7bcf5bc -# Use the latest commit of 10.4-GA-ORT-DDS -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/9f98e2ebe7507fe0774d06a44bbf4b0e82cc9ce7.zip;1d92137f424513bce20033ab4fb31cc0be8d1185 +# Use the latest commit of 10.6-GA-ORT-DDS +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bc0d2e35909b8456abe32f3b30a49bb0c125e8b7.zip;f233ae871ad82c023da62e5dd620639f00bc2d15 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 @@ -59,5 +58,5 @@ 
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d839 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1 -dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99 +dawn;https://github.com/google/dawn/archive/12a3b24c456cebd9fd11f23ac0164f78129b00c6.zip;ad428f6dc16f1336d584f7bad5714e1097dafc43 kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/v0.2.0/kleidiai-v0.2.0.zip;B1E3173992FD91F20DB904AB77D6E901778C2681 diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index dda7c5ff19ba4..7b6e2141eeb1b 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -27,7 +27,7 @@ FetchContent_Declare( URL ${DEP_URL_abseil_cpp} URL_HASH SHA1=${DEP_SHA1_abseil_cpp} PATCH_COMMAND ${ABSL_PATCH_COMMAND} - FIND_PACKAGE_ARGS NAMES absl + FIND_PACKAGE_ARGS 20240722 NAMES absl ) onnxruntime_fetchcontent_makeavailable(abseil_cpp) diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index a4fb63b6a8377..e995e215432a2 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -1,6 +1,6 @@ - + diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index 4230eb8f4259b..b388a01209f4e 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -1,10 +1,12 @@ -set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) +set(PATCH_CLANG ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) +set(PATCH_GFX12X ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Add_gfx12x_support.patch) include(FetchContent) FetchContent_Declare(composable_kernel URL ${DEP_URL_composable_kernel} URL_HASH SHA1=${DEP_SHA1_composable_kernel} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_CLANG} && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX12X} ) FetchContent_GetProperties(composable_kernel) diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index e03506de12728..3cfcdd4b04c62 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.2) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.4) # Restore nuget packages, which will pull down the DirectML redist package. 
add_custom_command( diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 339cded091b29..95dd438702a18 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -15,6 +15,7 @@ else () eigen URL ${DEP_URL_eigen} URL_HASH SHA1=${DEP_SHA1_eigen} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/eigen/eigen-edge.patch ) endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 85746027d4e8c..aeaaa7b51d595 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -86,27 +86,6 @@ if (onnxruntime_BUILD_BENCHMARKS) onnxruntime_fetchcontent_makeavailable(google_benchmark) endif() -if (NOT WIN32) - FetchContent_Declare( - google_nsync - URL ${DEP_URL_google_nsync} - URL_HASH SHA1=${DEP_SHA1_google_nsync} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/nsync/nsync_1.26.0.patch - FIND_PACKAGE_ARGS NAMES nsync unofficial-nsync - ) - #nsync tests failed on Mac Build - set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE) - onnxruntime_fetchcontent_makeavailable(google_nsync) - - if (google_nsync_SOURCE_DIR) - add_library(nsync::nsync_cpp ALIAS nsync_cpp) - target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public) - endif() - if(TARGET unofficial::nsync::nsync_cpp AND NOT TARGET nsync::nsync_cpp) - message(STATUS "Aliasing unofficial::nsync::nsync_cpp to nsync::nsync_cpp") - add_library(nsync::nsync_cpp ALIAS unofficial::nsync::nsync_cpp) - endif() -endif() if(onnxruntime_USE_MIMALLOC) FetchContent_Declare( @@ -636,17 +615,39 @@ if (onnxruntime_USE_COREML) endif() if (onnxruntime_USE_WEBGPU) - FetchContent_Declare( - dawn - URL ${DEP_URL_dawn} - URL_HASH SHA1=${DEP_SHA1_dawn} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch - ) + if (onnxruntime_CUSTOM_DAWN_SRC_PATH) + # use the custom dawn source path if provided + # + # specified as: + # build.py --use_webgpu --cmake_extra_defines "onnxruntime_CUSTOM_DAWN_SRC_PATH=" + FetchContent_Declare( + dawn + SOURCE_DIR ${onnxruntime_CUSTOM_DAWN_SRC_PATH} + ) + else() + FetchContent_Declare( + dawn + URL ${DEP_URL_dawn} + URL_HASH SHA1=${DEP_SHA1_dawn} + # All previous patches are merged into the upstream dawn project. We don't need to apply any patches right now. + # if we need to apply patches in the future, we can uncomment the following line. 
+ # PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch + ) + endif() + + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE) - # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size - set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + if (onnxruntime_USE_EXTERNAL_DAWN) + message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") + endif() + else() + # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size + set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) + endif() set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) @@ -675,13 +676,34 @@ if (onnxruntime_USE_WEBGPU) set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) - # Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it. - set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) + message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + endif() + # We are currently always using the D3D12 backend. 
+ set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() onnxruntime_fetchcontent_makeavailable(dawn) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native dawn::dawn_proc) + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::webgpu_dawn) + else() + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native) + endif() + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc) + endif() endif() set(onnxruntime_LINK_DIRS) diff --git a/cmake/external/tvm.cmake b/cmake/external/tvm.cmake deleted file mode 100644 index 93049c8b85853..0000000000000 --- a/cmake/external/tvm.cmake +++ /dev/null @@ -1,24 +0,0 @@ -if (onnxruntime_USE_TVM) - message(STATUS "onnxruntime_USE_TVM: Fetch tvm for TVM EP") - - FetchContent_Declare( - tvm - GIT_REPOSITORY https://github.com/apache/tvm.git - GIT_TAG 2379917985919ed3918dc12cad47f469f245be7a - ) - - FetchContent_GetProperties(tvm) - if(NOT tvm_POPULATED) - FetchContent_Populate(tvm) - if (WIN32) - execute_process( - COMMAND ${CMAKE_COMMAND} -E create_symlink ${tvm_BINARY_DIR}/${CMAKE_BUILD_TYPE} ${tvm_SOURCE_DIR}/build - ) - else() - file(CREATE_LINK ${tvm_BINARY_DIR} ${tvm_SOURCE_DIR}/build SYMBOLIC) - endif() - endif() - - set(tvm_INCLUDE_DIRS ${tvm_SOURCE_DIR}/include) - -endif() diff --git a/cmake/hip_fatbin_insert b/cmake/hip_fatbin_insert new file mode 100644 index 0000000000000..7d834cbf569f0 --- /dev/null +++ b/cmake/hip_fatbin_insert @@ -0,0 +1,7 @@ +SECTIONS { + .hipFatBinSegment : { *(.hipFatBinSegment) } +} INSERT AFTER .bss + +SECTIONS { + .hip_fatbin : { *(.hip_fatbin) } +} INSERT AFTER .hipFatBinSegment diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 08b8ca0cb66de..732c0511d400f 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -122,8 +122,12 @@ else() else() onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c ) endif() - if (onnxruntime_USE_CUDA) - set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN") + if(NOT APPLE) + include(CheckLinkerFlag) + check_linker_flag(CXX "LINKER:-rpath=\$ORIGIN" LINKER_SUPPORT_RPATH) + if(LINKER_SUPPORT_RPATH) + target_link_options(onnxruntime PRIVATE "LINKER:-rpath=\$ORIGIN") + endif() endif() endif() @@ -139,17 +143,17 @@ target_compile_definitions(onnxruntime PRIVATE FILE_NAME=\"onnxruntime.dll\") if(UNIX) if (APPLE) - set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker -dead_strip") + target_link_options(onnxruntime PRIVATE "LINKER:-dead_strip") elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker --version-script=${SYMBOL_FILE} -Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") + target_link_options(onnxruntime PRIVATE "LINKER:--version-script=${SYMBOL_FILE}" "LINKER:--no-undefined" "LINKER:--gc-sections") endif() else() - set(ONNXRUNTIME_SO_LINK_FLAG " -DEF:${SYMBOL_FILE}") + target_link_options(onnxruntime PRIVATE "-DEF:${SYMBOL_FILE}") endif() -if (NOT WIN32) - if (APPLE OR ${CMAKE_SYSTEM_NAME} MATCHES "^iOS") - set(ONNXRUNTIME_SO_LINK_FLAG " -Wl,-exported_symbols_list,${SYMBOL_FILE}") + +if (APPLE OR ${CMAKE_SYSTEM_NAME} MATCHES "^iOS") + target_link_options(onnxruntime PRIVATE "LINKER:-exported_symbols_list,${SYMBOL_FILE}") if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") set_target_properties(onnxruntime PROPERTIES MACOSX_RPATH TRUE @@ -159,12 +163,10 @@ if (NOT WIN32) else() set_target_properties(onnxruntime PROPERTIES INSTALL_RPATH "@loader_path") endif() 
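# Side note on the onnxruntime.cmake hunks just above: hand-assembled
# "-Xlinker ..." strings are replaced by target_link_options() with the portable
# LINKER: prefix, probed first via check_linker_flag() (CheckLinkerFlag requires
# CMake 3.18+). A minimal sketch with a hypothetical target name, assuming a
# project() with CXX enabled:
include(CheckLinkerFlag)
check_linker_flag(CXX "LINKER:-rpath=\$ORIGIN" DEMO_LINKER_SUPPORTS_RPATH)
if(DEMO_LINKER_SUPPORTS_RPATH)
  # CMake rewrites the LINKER: prefix into -Wl,... or -Xlinker ... as the toolchain expects.
  target_link_options(demo_lib PRIVATE "LINKER:-rpath=\$ORIGIN")
endif()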
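# Side note on the onnxruntime_external_deps.cmake hunks further above: WebGPU
# builds can now choose the Dawn flavour at configure time. A hedged usage
# sketch (variable names are from this change; the exact command line is
# illustrative only, not the official build.py flow):
#   cmake -S cmake -B build -Donnxruntime_USE_WEBGPU=ON \
#     -Donnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=ON \
#     -Donnxruntime_ENABLE_DAWN_BACKEND_VULKAN=ON \
#     -Donnxruntime_ENABLE_DAWN_BACKEND_D3D12=OFF
# Guard rails added above: the monolithic library cannot be combined with
# onnxruntime_USE_EXTERNAL_DAWN, and on Windows at least one of the two Dawn
# backend options must stay ON or configuration fails with FATAL_ERROR.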
- elseif (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'") - endif() endif() + if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_MINIMAL_BUILD) # target onnxruntime is a shared library, the dummy __cxa_demangle is only attach to it to avoid # affecting downstream ort library users with the behavior of dummy __cxa_demangle. So the dummy @@ -208,7 +210,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_NNAPI} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} - ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} @@ -219,7 +220,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${onnxruntime_winml} onnxruntime_optimizer onnxruntime_providers - ${onnxruntime_tvm_libs} onnxruntime_lora onnxruntime_framework onnxruntime_graph @@ -248,7 +248,9 @@ target_link_libraries(onnxruntime PRIVATE ${onnxruntime_EXTERNAL_LIBRARIES} ) -set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_SO_LINK_FLAG} ${onnxruntime_DELAYLOAD_FLAGS}) +if(WIN32) + target_link_options(onnxruntime PRIVATE ${onnxruntime_DELAYLOAD_FLAGS}) +endif() #See: https://cmake.org/cmake/help/latest/prop_tgt/SOVERSION.html if(NOT APPLE AND NOT WIN32) if(${CMAKE_SYSTEM_NAME} MATCHES "AIX") @@ -393,8 +395,23 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) list(APPEND lib_and_dependencies ${cur_target}) - get_target_property(link_libraries ${cur_target} LINK_LIBRARIES) - foreach(dependency ${link_libraries}) + set(all_link_libraries) + + get_property(link_libraries_set TARGET ${cur_target} PROPERTY LINK_LIBRARIES SET) + if(link_libraries_set) + get_target_property(link_libraries ${cur_target} LINK_LIBRARIES) + list(APPEND all_link_libraries ${link_libraries}) + endif() + + get_property(interface_link_libraries_set TARGET ${cur_target} PROPERTY INTERFACE_LINK_LIBRARIES SET) + if(interface_link_libraries_set) + get_target_property(interface_link_libraries ${cur_target} INTERFACE_LINK_LIBRARIES) + list(APPEND all_link_libraries ${interface_link_libraries}) + endif() + + list(REMOVE_DUPLICATES all_link_libraries) + + foreach(dependency ${all_link_libraries}) if(TARGET ${dependency}) process(${dependency}) endif() diff --git a/cmake/onnxruntime_codegen_tvm.cmake b/cmake/onnxruntime_codegen_tvm.cmake deleted file mode 100644 index 7b50d8f8603ae..0000000000000 --- a/cmake/onnxruntime_codegen_tvm.cmake +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -file(GLOB_RECURSE onnxruntime_codegen_common_srcs - "${ONNXRUNTIME_ROOT}/core/codegen/common/*.h" - "${ONNXRUNTIME_ROOT}/core/codegen/common/*.cc" -) - -file(GLOB_RECURSE onnxruntime_codegen_tvm_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/codegen/mti/*.h" - "${ONNXRUNTIME_ROOT}/core/codegen/mti/*.cc" - "${ONNXRUNTIME_ROOT}/core/codegen/passes/*.h" - "${ONNXRUNTIME_ROOT}/core/codegen/passes/*.cc" -) - -source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_codegen_common_srcs} ${onnxruntime_codegen_tvm_srcs}) - -#onnxruntime_codegen_tvm depends on onnxruntime framework -onnxruntime_add_static_library(onnxruntime_codegen_tvm ${onnxruntime_codegen_common_srcs} ${onnxruntime_codegen_tvm_srcs}) -set_target_properties(onnxruntime_codegen_tvm PROPERTIES FOLDER "ONNXRuntime") -target_include_directories(onnxruntime_codegen_tvm PRIVATE ${ONNXRUNTIME_ROOT} ${TVM_INCLUDES} ${MKLML_INCLUDE_DIR} ${eigen_INCLUDE_DIRS}) -onnxruntime_add_include_to_target(onnxruntime_codegen_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11) -target_compile_options(onnxruntime_codegen_tvm PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) -# need onnx to build to create headers that this project includes -add_dependencies(onnxruntime_codegen_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/cmake/onnxruntime_csharp.cmake b/cmake/onnxruntime_csharp.cmake index 22c993d07f7f9..39533429e181c 100644 --- a/cmake/onnxruntime_csharp.cmake +++ b/cmake/onnxruntime_csharp.cmake @@ -30,10 +30,6 @@ if (onnxruntime_USE_NNAPI_BUILTIN) STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_NNAPI;") endif() -if (onnxruntime_USE_TVM) - STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_TVM,") -endif() - if (onnxruntime_USE_OPENVINO) STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_OPENVINO;") endif() diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 765ebab111ac7..b15b9632e9e24 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -7,7 +7,7 @@ include(FindJava) find_package(Java REQUIRED) include(UseJava) -if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android") +if (NOT ANDROID) find_package(JNI REQUIRED) endif() @@ -21,23 +21,28 @@ endif() set(GRADLE_EXECUTABLE "${JAVA_ROOT}/gradlew") +set(COMMON_GRADLE_ARGS --console=plain) +if(WIN32) + list(APPEND COMMON_GRADLE_ARGS -Dorg.gradle.daemon=false) +elseif (ANDROID) + # For Android build, we may run gradle multiple times in same build, + # sometimes gradle JVM will run out of memory if we keep the daemon running + # it is better to not keep a daemon running + list(APPEND COMMON_GRADLE_ARGS --no-daemon) +endif() + # Specify the Java source files file(GLOB_RECURSE onnxruntime4j_gradle_files "${JAVA_ROOT}/*.gradle") file(GLOB_RECURSE onnxruntime4j_src "${JAVA_ROOT}/src/main/java/ai/onnxruntime/*.java") set(JAVA_OUTPUT_JAR ${JAVA_ROOT}/build/libs/onnxruntime.jar) # this jar is solely used to signaling mechanism for dependency management in CMake # if any of the Java sources change, the jar (and generated headers) will be regenerated and the onnxruntime4j_jni target will be rebuilt -set(GRADLE_ARGS --console=plain clean jar -x test) -if(WIN32) - set(GRADLE_ARGS ${GRADLE_ARGS} -Dorg.gradle.daemon=false) -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") - # For Android build, we may run gradle multiple times in same build, - # sometimes gradle JVM will run out of memory if we keep the daemon running - # it is better to not keep a daemon running - set(GRADLE_ARGS ${GRADLE_ARGS} 
--no-daemon) -endif() +set(GRADLE_ARGS clean jar -x test) -add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_ARGS} WORKING_DIRECTORY ${JAVA_ROOT} DEPENDS ${onnxruntime4j_gradle_files} ${onnxruntime4j_src}) +add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} + COMMAND ${GRADLE_EXECUTABLE} ${COMMON_GRADLE_ARGS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_ROOT} + DEPENDS ${onnxruntime4j_gradle_files} ${onnxruntime4j_src}) add_custom_target(onnxruntime4j DEPENDS ${JAVA_OUTPUT_JAR}) set_source_files_properties(${JAVA_OUTPUT_JAR} PROPERTIES GENERATED TRUE) set_property(TARGET onnxruntime4j APPEND PROPERTY ADDITIONAL_CLEAN_FILES "${JAVA_OUTPUT_DIR}") @@ -62,7 +67,7 @@ target_link_libraries(onnxruntime4j_jni PUBLIC onnxruntime) set(JAVA_PACKAGE_OUTPUT_DIR ${JAVA_OUTPUT_DIR}/build) file(MAKE_DIRECTORY ${JAVA_PACKAGE_OUTPUT_DIR}) -if (CMAKE_SYSTEM_NAME STREQUAL "Android") +if (ANDROID) set(ANDROID_PACKAGE_OUTPUT_DIR ${JAVA_PACKAGE_OUTPUT_DIR}/android) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_OUTPUT_DIR}) endif() @@ -88,7 +93,7 @@ if(APPLE) elseif(JNI_ARCH STREQUAL "arm64") set(JNI_ARCH aarch64) endif() -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") +elseif (ANDROID) set(JNI_ARCH ${ANDROID_ABI}) elseif (ARM64) set(JNI_ARCH aarch64) @@ -180,15 +185,7 @@ else() endif() # run the build process (this copies the results back into CMAKE_CURRENT_BINARY_DIR) -set(GRADLE_ARGS --console=plain cmakeBuild -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}) -if(WIN32) - set(GRADLE_ARGS ${GRADLE_ARGS} -Dorg.gradle.daemon=false) -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") - # For Android build, we may run gradle multiple times in same build, - # sometimes gradle JVM will run out of memory if we keep the daemon running - # it is better to not keep a daemon running - set(GRADLE_ARGS ${GRADLE_ARGS} --no-daemon) -endif() +set(GRADLE_ARGS cmakeBuild -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}) # Append relevant native build flags to gradle command set(GRADLE_ARGS ${GRADLE_ARGS} ${ORT_PROVIDER_FLAGS}) @@ -197,9 +194,11 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() message(STATUS "GRADLE_ARGS: ${GRADLE_ARGS}") -add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_ARGS} WORKING_DIRECTORY ${JAVA_ROOT}) +add_custom_command(TARGET onnxruntime4j_jni POST_BUILD + COMMAND ${GRADLE_EXECUTABLE} ${COMMON_GRADLE_ARGS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_ROOT}) -if (CMAKE_SYSTEM_NAME STREQUAL "Android") +if (ANDROID) set(ANDROID_PACKAGE_JNILIBS_DIR ${JAVA_OUTPUT_DIR}/android) set(ANDROID_PACKAGE_ABI_DIR ${ANDROID_PACKAGE_JNILIBS_DIR}/${ANDROID_ABI}) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_JNILIBS_DIR}) @@ -214,6 +213,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android") POST_BUILD COMMAND ${CMAKE_COMMAND} -E echo "Generating Android AAR package..." COMMAND ${GRADLE_EXECUTABLE} + ${COMMON_GRADLE_ARGS} build -b build-android.gradle -c settings-android.gradle -DjniLibsDir=${ANDROID_PACKAGE_JNILIBS_DIR} -DbuildDir=${ANDROID_PACKAGE_OUTPUT_DIR} @@ -237,6 +237,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android") POST_BUILD COMMAND ${CMAKE_COMMAND} -E echo "Building and running Android test for Android AAR package..." 
COMMAND ${GRADLE_EXECUTABLE} + ${COMMON_GRADLE_ARGS} clean assembleDebug assembleDebugAndroidTest -DminSdkVer=${ANDROID_MIN_SDK} --stacktrace diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake index 7de4f7b3f926b..62a6d45088052 100644 --- a/cmake/onnxruntime_kernel_explorer.cmake +++ b/cmake/onnxruntime_kernel_explorer.cmake @@ -64,7 +64,7 @@ elseif (onnxruntime_USE_ROCM) ) auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs}) target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs}) - target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1) + target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1 HIPBLAS_V2) if (onnxruntime_USE_COMPOSABLE_KERNEL) target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL) if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 0ba4694c329e3..5124262ec0004 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -36,11 +36,13 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/sqnbitgemm.h - ${MLAS_SRC_DIR}/sqnbitgemm.cpp + ${MLAS_SRC_DIR}/qnbitgemm.h + ${MLAS_SRC_DIR}/qnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/rotary_embedding.h + ${MLAS_SRC_DIR}/rotary_embedding.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -84,11 +86,15 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -362,10 +368,12 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -383,7 +391,9 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") 
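The per-file " -march=armv8.2-a+fp16 " COMPILE_FLAGS used for these NEON kernels keep the extended ISA requirement scoped to individual translation units instead of raising the baseline architecture for all of MLAS. The same scoping can also be expressed with the list-valued COMPILE_OPTIONS source property; a sketch with a hypothetical file name, not one added by this change:

# Equivalent per-file scoping via the COMPILE_OPTIONS source property (CMake >= 3.11).
set_source_files_properties(${MLAS_SRC_DIR}/my_fp16_kernel.cpp
  PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16")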
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -393,7 +403,9 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) @@ -453,7 +465,6 @@ else() bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); return 0; } - } #endif" HAS_P10_RUNTIME ) @@ -677,6 +688,13 @@ endif() if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/scalar/*.cpp") + elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) + file(GLOB_RECURSE mlas_platform_srcs_generic + "${MLAS_SRC_DIR}/scalar/*.cpp") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_generic} + ) endif() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() @@ -743,7 +761,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD) target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) endif() if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${CMAKE_DL_LIBS}) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 9666877cdc206..582491de9503d 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -101,9 +101,6 @@ endif() if(onnxruntime_USE_ROCM) set(PROVIDERS_ROCM onnxruntime_providers_rocm) endif() -if (onnxruntime_USE_TVM) - set(PROVIDERS_TVM onnxruntime_providers_tvm) -endif() if (onnxruntime_USE_XNNPACK) set(PROVIDERS_XNNPACK onnxruntime_providers_xnnpack) endif() @@ -194,10 +191,6 @@ if (onnxruntime_USE_ROCM) include(onnxruntime_providers_rocm.cmake) endif() -if (onnxruntime_USE_TVM) - include(onnxruntime_providers_tvm.cmake) -endif() - if (onnxruntime_USE_VSINPU) include(onnxruntime_providers_vsinpu.cmake) endif() diff --git a/cmake/onnxruntime_providers_cann.cmake b/cmake/onnxruntime_providers_cann.cmake index 0e26f7ee3a57b..2b82379ed66a9 100644 --- a/cmake/onnxruntime_providers_cann.cmake +++ b/cmake/onnxruntime_providers_cann.cmake @@ -21,7 +21,7 @@ onnxruntime_add_include_to_target(onnxruntime_providers_cann onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_cann onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync::nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED}) + 
target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED}) target_link_directories(onnxruntime_providers_cann PRIVATE ${onnxruntime_CANN_HOME}/lib64) target_include_directories(onnxruntime_providers_cann PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${onnxruntime_CANN_HOME} ${onnxruntime_CANN_HOME}/include) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 774b7a4f6bd77..4f86717026118 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -224,8 +224,7 @@ include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} - PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") @@ -275,10 +274,8 @@ if(APPLE) set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/cuda/exported_symbols.lst") - target_link_libraries(${target} PRIVATE nsync::nsync_cpp) elseif(UNIX) set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/cuda/version_script.lds -Xlinker --gc-sections") - target_link_libraries(${target} PRIVATE nsync::nsync_cpp) elseif(WIN32) set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/cuda/symbols.def") else() diff --git a/cmake/onnxruntime_providers_dml.cmake b/cmake/onnxruntime_providers_dml.cmake index 439be882dcc5e..3141aa85a1163 100644 --- a/cmake/onnxruntime_providers_dml.cmake +++ b/cmake/onnxruntime_providers_dml.cmake @@ -61,8 +61,9 @@ target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib) - if (NOT GDK_PLATFORM) - set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:dxcore.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199") + if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS AND NOT GDK_PLATFORM) + #NOTE: the flags are only applied to onnxruntime.dll and the PYD file in our python package. Our C/C++ unit tests do not use these flags. 
+ list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:DirectML.dll" "/DELAYLOAD:d3d12.dll" "/DELAYLOAD:dxgi.dll" "/DELAYLOAD:dxcore.dll" "/DELAYLOAD:api-ms-win-core-com-l1-1-0.dll" "/DELAYLOAD:shlwapi.dll" "/DELAYLOAD:oleaut32.dll" "/DELAYLOAD:ext-ms-win-dxcore-l1-*.dll" "/ignore:4199") endif() target_compile_definitions(onnxruntime_providers_dml diff --git a/cmake/onnxruntime_providers_dnnl.cmake b/cmake/onnxruntime_providers_dnnl.cmake index f2965728524b7..9e5a7eed44fff 100644 --- a/cmake/onnxruntime_providers_dnnl.cmake +++ b/cmake/onnxruntime_providers_dnnl.cmake @@ -41,10 +41,8 @@ INSTALL_RPATH "@loader_path" BUILD_WITH_INSTALL_RPATH TRUE INSTALL_RPATH_USE_LINK_PATH FALSE) - target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) elseif(UNIX) set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/dnnl/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\$ORIGIN") - target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) elseif(WIN32) set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/dnnl/symbols.def") else() diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index d7d83b0ce8d64..685e77bc483bd 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -57,7 +57,7 @@ endif() if(UNIX) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 2eb3611bae902..f5fae8d169ccc 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -11,22 +11,22 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) - if (WIN32) - set(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO Release) - endif() - # Header paths find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) - if(OpenVINO_VERSION VERSION_LESS 2024.0) - message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. Please, use latest OpenVINO release") + if(OpenVINO_VERSION VERSION_LESS 2024.4) + message(FATAL_ERROR "OpenVINO 2024.4 and newer are supported. 
Please, use latest OpenVINO release") endif() if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4) add_definitions(-DUSE_OVEP_NPU_MEMORY=1) endif() - if (WIN32) - unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) + # If building RelWithDebInfo and OV package does not have that configuration map to Release + get_target_property(ov_rt_implib_rwdi openvino::runtime IMPORTED_IMPLIB_RELWITHDEBINFO) + if ((CMAKE_BUILD_TYPE STREQUAL RelWithDebInfo) AND NOT ov_rt_implib_rwdi) + set_target_properties(openvino::runtime PROPERTIES + MAP_IMPORTED_CONFIG_RELWITHDEBINFO Release + ) endif() list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES}) @@ -37,7 +37,7 @@ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_openvino ${onnxruntime_providers_openvino_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") - onnxruntime_add_include_to_target(onnxruntime_providers_openvino onnxruntime_common onnx) + onnxruntime_add_include_to_target(onnxruntime_providers_openvino onnxruntime_common onnx nlohmann_json::nlohmann_json) install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/openvino/openvino_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) set_target_properties(onnxruntime_providers_openvino PROPERTIES CXX_STANDARD 20) @@ -82,3 +82,8 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + +set_target_properties(onnxruntime_providers_openvino PROPERTIES + MAP_IMPORTED_CONFIG_RELEASE RelWithDebInfo + MAP_IMPORTED_CONFIG_DEBUG RelWithDebInfo + ) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 559204bd0df88..68f5319c0ae8d 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -8,7 +8,7 @@ find_package(HIP) find_package(hiprand REQUIRED) - find_package(rocblas REQUIRED) + find_package(hipblas REQUIRED) find_package(MIOpen REQUIRED) find_package(hipfft REQUIRED) @@ -50,7 +50,7 @@ find_library(RCCL_LIB rccl REQUIRED) find_library(ROCTRACER_LIB roctracer64 REQUIRED) find_package(rocm_smi REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB}) + set(ONNXRUNTIME_ROCM_LIBS roc::hipblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB}) include_directories(${ROCM_SMI_INCLUDE_DIR}) link_directories(${ROCM_SMI_LIB_DIR}) @@ -116,6 +116,7 @@ auto_set_source_files_hip_language(${onnxruntime_providers_rocm_src}) onnxruntime_add_shared_library_module(onnxruntime_providers_rocm ${onnxruntime_providers_rocm_src}) target_compile_options(onnxruntime_providers_rocm PRIVATE -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) + target_link_options(onnxruntime_providers_rocm PRIVATE -T ${REPO_ROOT}/cmake/hip_fatbin_insert) if(NOT MSVC) target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-sign-compare) @@ -154,6 +155,7 @@ set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime") + target_compile_definitions(onnxruntime_providers_rocm PRIVATE HIPBLAS_V2) if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS}) @@ -215,7 +217,6 @@ if(UNIX) set_property(TARGET 
onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync::nsync_cpp) else() message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it") endif() diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 468aaa44ec4ee..7b18222f334f9 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -206,11 +206,9 @@ if(APPLE) set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/tensorrt/exported_symbols.lst") - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp) elseif(UNIX) set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/tensorrt/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp) elseif(WIN32) set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/tensorrt/symbols.def") else() diff --git a/cmake/onnxruntime_providers_tvm.cmake b/cmake/onnxruntime_providers_tvm.cmake deleted file mode 100644 index 8fd50c70dd5d7..0000000000000 --- a/cmake/onnxruntime_providers_tvm.cmake +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
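The version-script and --gc-sections flags set above through APPEND_STRING LINK_FLAGS can equivalently be written with CMake's portable LINKER: prefix, the form this change adopts for the Python bindings further down; a sketch with a hypothetical provider target name:

# Sketch only; "my_provider" is a placeholder target, not one defined in this build.
if(UNIX AND NOT APPLE)
  target_link_options(my_provider PRIVATE
    "LINKER:--version-script=${ONNXRUNTIME_ROOT}/core/providers/my_provider/version_script.lds"
    "LINKER:--gc-sections")
endif()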
- - add_definitions(-DUSE_TVM=1) - if (onnxruntime_TVM_USE_HASH) - add_definitions(-DUSE_TVM_HASH=1) - endif() - - if (onnxruntime_TVM_USE_HASH) - file (GLOB_RECURSE onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" - ) - else() - file (GLOB onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" - ) - endif() - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tvm_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_tvm ${onnxruntime_providers_tvm_cc_srcs}) - - if ( CMAKE_COMPILER_IS_GNUCC ) - target_compile_options(onnxruntime_providers_tvm PRIVATE -Wno-unused-parameter -Wno-missing-field-initializers) - endif() - - target_include_directories(onnxruntime_providers_tvm PRIVATE - ${TVM_INCLUDES} - ${PYTHON_INCLUDE_DIRS}) - onnxruntime_add_include_to_target(onnxruntime_providers_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) - - add_dependencies(onnxruntime_providers_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) - - if (onnxruntime_TVM_USE_HASH) - add_dependencies(onnxruntime_providers_tvm ippcp_s) - target_include_directories(onnxruntime_providers_tvm PRIVATE ${IPP_CRYPTO_INCLUDE_DIR}) - target_link_libraries(onnxruntime_providers_tvm PRIVATE ippcp_s) - endif() - - set_target_properties(onnxruntime_providers_tvm PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_tvm PROPERTIES LINKER_LANGUAGE CXX) - - if (WIN32 AND MSVC) - # wd4100: identifier' : unreferenced formal parameter - # wd4127: conditional expression is constant - # wd4244: conversion from 'int' to 'char', possible loss of data - # TODO: 4244 should not be disabled - target_compile_options(onnxruntime_providers_tvm PRIVATE "/wd4100" "/wd4127" "/wd4244") - else() - target_compile_options(onnxruntime_providers_tvm PRIVATE "-Wno-error=type-limits") - endif() - target_compile_definitions(onnxruntime_providers_tvm PUBLIC DMLC_USE_LOGGING_LIBRARY=) - - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tvm/tvm_provider_factory.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_tvm - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 764cde9491da8..561a323533f48 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -12,6 +12,7 @@ file(GLOB onnxruntime_providers_vitisai_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include/vaip/*.h" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" diff --git a/cmake/onnxruntime_providers_vsinpu.cmake b/cmake/onnxruntime_providers_vsinpu.cmake index 4b987fd1e424b..e3b6c3c302c82 100644 --- a/cmake/onnxruntime_providers_vsinpu.cmake +++ b/cmake/onnxruntime_providers_vsinpu.cmake @@ -11,7 +11,7 @@ 
add_library(onnxruntime_providers_vsinpu ${onnxruntime_providers_vsinpu_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_vsinpu onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers Boost::mp11 - safeint_interface nsync::nsync_cpp) + safeint_interface ) add_dependencies(onnxruntime_providers_vsinpu ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_vsinpu PROPERTIES FOLDER "ONNXRuntime" LINKER_LANGUAGE CXX) target_include_directories(onnxruntime_providers_vsinpu PRIVATE ${ONNXRUNTIME_ROOT} $ENV{TIM_VX_INSTALL}/include) diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index eb25c55ab23e0..fea5964f0dda9 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -22,6 +22,25 @@ onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native dawn::dawn_proc) + + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + + if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS) + list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll") + endif() + + # Copy webgpu_dawn.dll to the output directory + add_custom_command( + TARGET onnxruntime_providers_webgpu + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" "$" + VERBATIM ) + else() + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native) + endif() + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc) + endif() set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 0d038d210ea2b..5a87252b08573 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -110,17 +110,17 @@ if (onnxruntime_USE_NCCL) endif() if(APPLE) - set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker -exported_symbols_list -Xlinker ${ONNXRUNTIME_ROOT}/python/exported_symbols.lst") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:-exported_symbols_list,${ONNXRUNTIME_ROOT}/python/exported_symbols.lst") elseif(UNIX) if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) - set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/python/version_script_expose_onnx_protobuf.lds -Xlinker --gc-sections") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:--version-script=${ONNXRUNTIME_ROOT}/python/version_script_expose_onnx_protobuf.lds" "LINKER:--gc-sections") else() if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") - set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/python/version_script.lds -Xlinker --gc-sections") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:--version-script=${ONNXRUNTIME_ROOT}/python/version_script.lds" "LINKER:--gc-sections") endif() endif() else() - set(ONNXRUNTIME_SO_LINK_FLAG "-DEF:${ONNXRUNTIME_ROOT}/python/pybind.def") + target_link_options(onnxruntime_pybind11_state PRIVATE "-DEF:${ONNXRUNTIME_ROOT}/python/pybind.def") endif() if (onnxruntime_ENABLE_ATEN) @@ -169,8 +169,8 @@ endif() target_link_libraries(onnxruntime_pybind11_state 
PRIVATE onnxruntime_session ${onnxruntime_libs} - ${PROVIDERS_TVM} ${PROVIDERS_NNAPI} + ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} ${PROVIDERS_RKNPU} @@ -184,7 +184,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_optimizer onnxruntime_providers onnxruntime_util - ${onnxruntime_tvm_libs} onnxruntime_lora onnxruntime_framework onnxruntime_util @@ -199,11 +198,11 @@ set(onnxruntime_pybind11_state_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES} ${pybind11_dep} ) -set_property(TARGET onnxruntime_pybind11_state APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_SO_LINK_FLAG} ${onnxruntime_DELAYLOAD_FLAGS}) + add_dependencies(onnxruntime_pybind11_state ${onnxruntime_pybind11_state_dependencies}) if (MSVC) - set_target_properties(onnxruntime_pybind11_state PROPERTIES LINK_FLAGS "${ONNXRUNTIME_SO_LINK_FLAG}") + target_link_options(onnxruntime_pybind11_state PRIVATE ${onnxruntime_DELAYLOAD_FLAGS}) # if MSVC, pybind11 undefines _DEBUG in pybind11/detail/common.h, which causes the pragma in pyconfig.h # from the python installation to require the release version of the lib # e.g. from a python 3.10 install: @@ -220,14 +219,15 @@ if (MSVC) # Explicitly use the release version of the python library to make the project file consistent with this. target_link_libraries(onnxruntime_pybind11_state PRIVATE ${Python_LIBRARY_RELEASE}) elseif (APPLE) - set_target_properties(onnxruntime_pybind11_state PROPERTIES LINK_FLAGS "${ONNXRUNTIME_SO_LINK_FLAG} -Xlinker -undefined -Xlinker dynamic_lookup") + # The following flag no longer works + #target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:-undefined,dynamic_lookup") set_target_properties(onnxruntime_pybind11_state PROPERTIES INSTALL_RPATH "@loader_path" BUILD_WITH_INSTALL_RPATH TRUE INSTALL_RPATH_USE_LINK_PATH FALSE) else() if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") - set_property(TARGET onnxruntime_pybind11_state APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:-rpath=\$ORIGIN") endif() endif() @@ -238,8 +238,8 @@ if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) MATH(EXPR PROTOBUF_INDEX_NEXT "${PROTOBUF_INDEX} + 1") if (ONNX_INDEX GREATER_EQUAL 0 AND PROTOBUF_INDEX GREATER_EQUAL 0) # Expect protobuf to follow onnx due to dependence - list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${ONNX_INDEX} "-Wl,--no-as-needed") - list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${PROTOBUF_INDEX_NEXT} "-Wl,--as-needed") + list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${ONNX_INDEX} "LINKER:--no-as-needed") + list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${PROTOBUF_INDEX_NEXT} "LINKER:--as-needed") else() message(FATAL_ERROR "Required external libraries onnx and protobuf are not found in onnxruntime_EXTERNAL_LIBRARIES") endif() @@ -964,37 +964,6 @@ if (onnxruntime_USE_ROCM) ) endif() -if (onnxruntime_USE_TVM) - file(GLOB onnxruntime_python_providers_tvm_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/providers/tvm/*.py" - ) - add_custom_command( - TARGET onnxruntime_pybind11_state POST_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/providers - COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/providers/tvm - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_providers_tvm_srcs} - $/onnxruntime/providers/tvm - COMMAND ${CMAKE_COMMAND} -E copy - $ - $/onnxruntime/capi/ - ) - - add_custom_command( - TARGET onnxruntime_pybind11_state POST_BUILD - WORKING_DIRECTORY ${tvm_SOURCE_DIR}/python - COMMAND 
${Python_EXECUTABLE} setup.py bdist_wheel - ) - - add_custom_command( - TARGET onnxruntime_pybind11_state POST_BUILD - COMMAND ${Python_EXECUTABLE} - $/onnxruntime/providers/tvm/extend_python_file.py - --target_file $/onnxruntime/capi/_ld_preload.py - ) - -endif() - if (onnxruntime_USE_DML) if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(dml_shared_lib_path ${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}-win/${DML_SHARED_LIB}) @@ -1050,4 +1019,13 @@ if (onnxruntime_USE_QNN) endif() endif() +if (onnxruntime_USE_VSINPU) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $/onnxruntime/capi/ + ) +endif() + endif() diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index fcddd2a51e0d1..111033c780712 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -157,10 +157,6 @@ set(provider_excluded_files "cuda_execution_provider_info.h" "cuda_execution_provider.cc" "cuda_execution_provider.h" - "cuda_memory_check.cc" - "cuda_memory_check.h" - "cuda_fence.cc" - "cuda_fence.h" "cuda_kernel.h" "cuda_pch.cc" "cuda_pch.h" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 619c3a784d5f9..e822f0a3655fc 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -9,9 +9,6 @@ set(TEST_INC_DIR ${ONNXRUNTIME_ROOT}) if (onnxruntime_ENABLE_TRAINING) list(APPEND TEST_INC_DIR ${ORTTRAINING_ROOT}) endif() -if (onnxruntime_USE_TVM) - list(APPEND TEST_INC_DIR ${TVM_INCLUDES}) -endif() set(disabled_warnings) function(AddTest) @@ -67,7 +64,10 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. 
- target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart cudnn_frontend) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) + if(NOT onnxruntime_CUDA_MINIMAL) + target_link_libraries(${_UT_TARGET} PRIVATE cudnn_frontend) + endif() endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -111,7 +111,6 @@ function(AddTest) endif() target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings}) else() - target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") if (${HAS_NOERROR}) @@ -134,9 +133,14 @@ function(AddTest) if (IOS) # target_sources(${_UT_TARGET} PRIVATE ${TEST_SRC_DIR}/xctest/orttestmain.m) + + set(_UT_IOS_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET}) + # replace any characters that are not valid in a bundle identifier with '-' + string(REGEX REPLACE "[^a-zA-Z0-9\\.-]" "-" _UT_IOS_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER}) + set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest" MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET} - MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET} + MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER} MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION} MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION} MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION} @@ -163,13 +167,31 @@ function(AddTest) set_target_properties(${_UT_TARGET}_xc PROPERTIES FOLDER "ONNXRuntimeXCTest" MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET}_xc - MACOSX_BUNDLE_GUI_IDENTIFIER com.onnxruntime.utest.${_UT_TARGET} + MACOSX_BUNDLE_GUI_IDENTIFIER ${_UT_IOS_BUNDLE_GUI_IDENTIFIER} MACOSX_BUNDLE_LONG_VERSION_STRING ${ORT_VERSION} MACOSX_BUNDLE_BUNDLE_VERSION ${ORT_VERSION} MACOSX_BUNDLE_SHORT_VERSION_STRING ${ORT_VERSION} XCODE_ATTRIBUTE_ENABLE_BITCODE "NO") - xctest_add_test(xctest.${_UT_TARGET} ${_UT_TARGET}_xc) + # This is a workaround for an Xcode 16 / CMake issue: + # error: Multiple commands produce '/Debug/Debug-iphonesimulator/onnxruntime_test_all.app/PlugIns' + # note: CreateBuildDirectory /Debug/Debug-iphonesimulator/onnxruntime_test_all.app/PlugIns + # note: Target 'onnxruntime_test_all' (project 'onnxruntime') has create directory command with output + # '/Debug/Debug-iphonesimulator/onnxruntime_test_all.app/PlugIns' + # + # It seems related to the test target (e.g., onnxruntime_test_all_xc) LIBRARY_OUTPUT_DIRECTORY property getting set + # to "$/PlugIns" in xctest_add_bundle(): + # https://github.com/Kitware/CMake/blob/9c4a0a9ff09735b847bbbc38caf6da7f6c7238f2/Modules/FindXCTest.cmake#L159-L168 + # + # This is the related CMake issue: https://gitlab.kitware.com/cmake/cmake/-/issues/26301 + # + # Unsetting LIBRARY_OUTPUT_DIRECTORY avoids the build error. + set_property(TARGET ${_UT_TARGET}_xc PROPERTY LIBRARY_OUTPUT_DIRECTORY) + + # Don't bother calling xctest_add_test() because we don't use CTest to run tests on iOS. + # Instead, we can call 'xcodebuild test-without-building' and specify a '-destination' referring to an iOS + # simulator or device. + # xctest_add_test(xctest.${_UT_TARGET} ${_UT_TARGET}_xc) else() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") # We might have already executed the following "find_program" code when we build ORT nodejs binding. 
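Because xctest_add_test() is no longer called, the iOS test bundles are expected to be run outside CTest with 'xcodebuild test-without-building', as the comment above notes. A minimal sketch of wrapping that in a convenience target — the project path, scheme name, and simulator destination below are assumptions, not values taken from this build:

# Hypothetical helper target; adjust the project, scheme, and destination to the local setup.
add_custom_target(run_ios_unit_tests
  COMMAND xcodebuild test-without-building
          -project ${CMAKE_BINARY_DIR}/onnxruntime.xcodeproj
          -scheme onnxruntime_test_all_xc
          -destination "platform=iOS Simulator,name=iPhone 15"
  VERBATIM)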
@@ -500,6 +522,9 @@ set (onnxruntime_global_thread_pools_test_SRC ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc) +set (onnxruntime_webgpu_external_dawn_test_SRC + ${TEST_SRC_DIR}/webgpu/external_dawn/main.cc) + # tests from lowest level library up. # the order of libraries should be maintained, with higher libraries being added first in the list @@ -615,13 +640,11 @@ set(ONNXRUNTIME_TEST_LIBS ${PROVIDERS_ACL} ${PROVIDERS_ARMNN} ${PROVIDERS_COREML} - # ${PROVIDERS_TVM} ${PROVIDERS_XNNPACK} ${PROVIDERS_AZURE} onnxruntime_optimizer onnxruntime_providers onnxruntime_util - ${onnxruntime_tvm_libs} onnxruntime_lora onnxruntime_framework onnxruntime_util @@ -723,12 +746,6 @@ if(onnxruntime_USE_AZURE) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_azure) endif() -if(WIN32) - if (onnxruntime_USE_TVM) - list(APPEND disabled_warnings ${DISABLED_WARNINGS_FOR_TVM}) - endif() -endif() - file(GLOB onnxruntime_test_framework_src CONFIGURE_DEPENDS ${onnxruntime_test_framework_src_patterns} ) @@ -743,9 +760,7 @@ if(MSVC) target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") else() - target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11) target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) - onnxruntime_add_include_to_target(onnxruntime_test_utils nsync::nsync_cpp) endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS}) @@ -779,9 +794,7 @@ if(NOT IOS) target_compile_options(onnx_test_runner_common PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") else() - target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11) target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) - onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp) endif() if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) #TODO: fix the warnings, they are dangerous @@ -833,9 +846,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) list(APPEND all_tests ${onnxruntime_test_training_api_src}) endif() -if (onnxruntime_USE_TVM) - list(APPEND all_tests ${onnxruntime_test_tvm_src}) -endif() if (onnxruntime_USE_OPENVINO) list(APPEND all_tests ${onnxruntime_test_openvino_src}) @@ -1067,15 +1077,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) COMMAND ${CMAKE_COMMAND} -E copy ${DNNL_DLL_PATH} $ ) endif() - if(WIN32) - if (onnxruntime_USE_TVM) - add_custom_command( - TARGET ${test_data_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy $ $ - ) - endif() - endif() - if(WIN32) set(wide_get_opt_src_dir ${TEST_SRC_DIR}/win_getopt/wide) onnxruntime_add_static_library(win_getopt_wide ${wide_get_opt_src_dir}/getopt.cc ${wide_get_opt_src_dir}/include/getopt.h) @@ -1117,12 +1118,6 @@ if (NOT IOS) endif() set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") - endif() - endif() - install(TARGETS onnx_test_runner ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -1146,7 +1141,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${BENCHMARK_DIR}/gelu.cc ${BENCHMARK_DIR}/activation.cc ${BENCHMARK_DIR}/quantize.cc - ${BENCHMARK_DIR}/reduceminmax.cc) + ${BENCHMARK_DIR}/reduceminmax.cc + 
${BENCHMARK_DIR}/layer_normalization.cc) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) if(WIN32) @@ -1183,7 +1179,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) # "Global initializer calls a non-constexpr function." BENCHMARK_CAPTURE macro needs this. target_compile_options(onnxruntime_mlas_benchmark PRIVATE /wd26426) else() - target_link_libraries(onnxruntime_mlas_benchmark PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) + target_link_libraries(onnxruntime_mlas_benchmark PRIVATE ${CMAKE_DL_LIBS}) endif() if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") target_link_libraries(onnxruntime_mlas_benchmark PRIVATE cpuinfo) @@ -1256,7 +1252,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) if(NOT WIN32) - list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp) if(onnxruntime_USE_SNPE) list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe) endif() @@ -1276,11 +1271,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") - endif() - endif() endif() @@ -1324,7 +1314,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) # test inference using shared lib set(onnxruntime_shared_lib_test_LIBS onnxruntime_mocked_allocator onnxruntime_test_utils onnxruntime_common onnx_proto) if(NOT WIN32) - list(APPEND onnxruntime_shared_lib_test_LIBS nsync::nsync_cpp) if(onnxruntime_USE_SNPE) list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_providers_snpe) endif() @@ -1473,7 +1462,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) endif() if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) + target_link_libraries(onnxruntime_mlas_test PRIVATE ${CMAKE_DL_LIBS}) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) @@ -1659,9 +1648,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") ${ONNXRUNTIME_CUSTOM_OP_REGISTRATION_TEST_SRC_DIR}/test_registercustomops.cc) set(onnxruntime_customopregistration_test_LIBS custom_op_library onnxruntime_common onnxruntime_test_utils) - if (NOT WIN32) - list(APPEND onnxruntime_customopregistration_test_LIBS nsync::nsync_cpp) - endif() + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") list(APPEND onnxruntime_customopregistration_test_LIBS cpuinfo) endif() @@ -1669,7 +1656,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp) + list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv 
re2 ${PROTOBUF_LIB} onnx onnx_proto) endif() AddTest(DYN TARGET onnxruntime_customopregistration_test @@ -1788,11 +1775,11 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils) if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_lora onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp) + list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_lora onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 ${PROTOBUF_LIB} onnx onnx_proto) endif() if(NOT WIN32) - list(APPEND onnxruntime_logging_apis_test_LIBS nsync::nsync_cpp ${CMAKE_DL_LIBS}) + list(APPEND onnxruntime_logging_apis_test_LIBS ${CMAKE_DL_LIBS}) endif() AddTest(DYN @@ -1868,4 +1855,13 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD endif() endif() +if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN) + AddTest(TARGET onnxruntime_webgpu_external_dawn_test + SOURCES ${onnxruntime_webgpu_external_dawn_test_SRC} + LIBS dawn::dawn_native ${onnxruntime_test_providers_libs} + DEPENDS ${all_dependencies} + ) + onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers) +endif() + include(onnxruntime_fuzz_test.cmake) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 3a1576065205f..66268cefac9ef 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -97,7 +97,6 @@ target_compile_options(onnx PRIVATE -Wno-unused-parameter -Wno-unused-variable) if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB) bundle_static_library(onnxruntime_webassembly - nsync::nsync_cpp ${PROTOBUF_LIB} onnx onnx_proto @@ -175,7 +174,6 @@ else() endif() target_link_libraries(onnxruntime_webassembly PRIVATE - nsync::nsync_cpp ${PROTOBUF_LIB} onnx onnx_proto @@ -194,9 +192,7 @@ else() onnxruntime_util re2::re2 ) - - set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") - + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'") if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'") @@ -217,10 +213,114 @@ else() set(EXPORTED_FUNCTIONS "_malloc,_free") endif() + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set(MAXIMUM_MEMORY "17179869184") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s MEMORY64=1" + ) + string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental") + set(SMEMORY_FLAG "-sMEMORY64") + + target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + 
target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + # target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_cpu_detect PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_conditions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_check_op PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_message PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_format PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_log_sink_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_sink PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_entry PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + 
target_compile_options(absl_log_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_vlog_config_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_fnmatch PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_examine_stack PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strerror PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_nullguard PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_logging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_severity PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + if (onnxruntime_USE_EXTENSIONS) + target_compile_options(ortcustomops PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(ocos_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + endif() + target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" + ) + else () + set(MAXIMUM_MEMORY "4294967296") + 
target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" + ) + endif () + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" - "SHELL:-s MAXIMUM_MEMORY=4294967296" + "SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}" "SHELL:-s EXIT_RUNTIME=0" "SHELL:-s ALLOW_MEMORY_GROWTH=1" "SHELL:-s MODULARIZE=1" @@ -233,6 +333,41 @@ else() --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set(SIGNATURE_CONVERSIONS "OrtRun:_pppppppp,\ +OrtRunWithBinding:_ppppp,\ +OrtGetTensorData:_ppppp,\ +OrtCreateTensor:p_pppp_,\ +OrtCreateSession:pppp,\ +OrtReleaseSession:_p,\ +OrtGetInputOutputCount:_ppp,\ +OrtCreateSessionOptions:pp__p_ppppp,\ +OrtReleaseSessionOptions:_p,\ +OrtAppendExecutionProvider:_pp,\ +OrtAddSessionConfigEntry:_ppp,\ +OrtGetInputName:ppp,\ +OrtGetOutputName:ppp,\ +OrtCreateRunOptions:ppp_p,\ +OrtReleaseRunOptions:_p,\ +OrtReleaseTensor:_p,\ +OrtFree:_p,\ +OrtCreateBinding:_p,\ +OrtBindInput:_ppp,\ +OrtBindOutput:_ppp_,\ +OrtClearBoundOutputs:_p,\ +OrtReleaseBinding:_p,\ +OrtGetLastError:_pp,\ +JsepOutput:pp_p,\ +JsepGetNodeName:pp,\ +JsepOutput:pp_p,\ +jsepCopy:_pp_,\ +jsepCopyAsync:_pp_,\ +jsepDownload:_pp_") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" + "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" + ) + endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) if (onnxruntime_USE_JSEP) @@ -245,6 +380,8 @@ else() "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" "SHELL:-s ASYNCIFY=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" + "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() @@ -281,7 +418,9 @@ else() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. 
- target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + endif() if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING) target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs) diff --git a/cmake/patches/composable_kernel/Add_gfx12x_support.patch b/cmake/patches/composable_kernel/Add_gfx12x_support.patch new file mode 100644 index 0000000000000..ef529184d2ed8 --- /dev/null +++ b/cmake/patches/composable_kernel/Add_gfx12x_support.patch @@ -0,0 +1,2280 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index bc326c8b5..db5ad5052 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -117,7 +117,7 @@ else() + add_definitions(-DPROFILER_ONLY) + set(GPU_TARGETS "" CACHE STRING "" FORCE) + if(GPU_TARGETS) +- message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") ++ message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") + endif() + if(GPU_ARCH MATCHES "gfx90") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") +@@ -127,8 +127,10 @@ else() + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") + elseif(GPU_ARCH MATCHES "gfx11") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") ++ elseif(GPU_ARCH MATCHES "gfx12") ++ rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") + else() +- message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") ++ message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") + endif() + set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) + endif() +diff --git a/Jenkinsfile b/Jenkinsfile +index 75800bfc9..b72e2ca4e 100644 +--- a/Jenkinsfile ++++ b/Jenkinsfile +@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){ + + def variant = env.STAGE_NAME + def retimage ++ + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + try { + (retimage, image) = getDockerImage(conf) +@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;ROCM + + pipeline { + agent none +- triggers { +- parameterizedCron(CRON_SETTINGS) +- } + options { + parallelsAlwaysFailFast() + } +diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake +index 8654170b3..42070051b 100644 +--- a/cmake/EnableCompilerWarnings.cmake ++++ b/cmake/EnableCompilerWarnings.cmake +@@ -66,7 +66,7 @@ else() + -Wunreachable-code + -Wunused + -Wno-reserved-identifier +- -Werror ++ -Werror + -Wno-option-ignored + -Wsign-compare + -Wno-extra-semi-stmt +diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp +index 8c52e4f7d..f8afe8d6d 100644 +--- a/example/01_gemm/gemm_wmma_fp16.cpp ++++ b/example/01_gemm/gemm_wmma_fp16.cpp +@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa + + // clang-format off + using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle +- < ALayout, +- BLayout, +- CLayout, +- ADataType, ++ < ALayout, ++ BLayout, ++ CLayout, ++ ADataType, + BDataType, +- CDataType, +- AccDataType, +- CShuffleDataType, +- AElementOp, +- BElementOp, +- CElementOp, +- GemmDefault, ++ CDataType, ++ AccDataType, ++ CShuffleDataType, ++ AElementOp, ++ BElementOp, ++ CElementOp, ++ GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock +- 8, // K1 ++ 2, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave +- S<4, 32, 1>, +- S<1, 0, 2>, +- S<1, 0, 2>, +- 2, +- 8, +- 8, +- true, +- S<4, 32, 1>, +- S<1, 0, 2>, +- S<1, 0, 2>, +- 2, +- 8, +- 8, +- true, ++ S<4, 32, 1>, ++ S<1, 0, 2>, ++ S<1, 0, 2>, ++ 2, ++ 2, ++ 2, ++ true, ++ S<4, 32, 1>, ++ S<1, 0, 2>, ++ S<1, 0, 2>, ++ 2, ++ 2, ++ 2, ++ true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store +- S<1, 32, 1, 4>, ++ S<1, 32, 1, 4>, + 8>; + // clang-format on + +diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc +index b04e4e53a..cb15186c3 100644 +--- a/example/01_gemm/run_gemm_example.inc ++++ b/example/01_gemm/run_gemm_example.inc +@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: +- ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); ++ ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: +diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt +index ab19f819e..be47665a2 100644 +--- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt ++++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt +@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) + set(target 1) + endif() +-endforeach() +\ No newline at end of file ++endforeach() +diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +index 2bbf430c4..f556be887 100644 +--- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp ++++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +@@ -83,14 +83,14 @@ using 
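Inside the bundled composable_kernel patch, the gemm_wmma_fp16 example's K1 (roughly, the per-thread vector length along K) drops from 8 to 2, presumably to match what the gfx12 WMMA path consumes; K1 times the element size is also exactly the quantity the later MaxVectorLoadA/B heuristics test (K1 * sizeof(T) == 16 bytes). A small constexpr sketch of that arithmetic, using uint16_t as a stand-in for fp16:

// Sketch of the vector-width arithmetic around K1 in this patch
// (names are illustrative, not taken from composable_kernel).
#include <cstdint>

template <typename T>
constexpr int bytes_per_access(int k1) { return k1 * static_cast<int>(sizeof(T)); }

// fp16 with K1 = 8 fills a full 16-byte (dwordx4) access ...
static_assert(bytes_per_access<std::uint16_t>(8) == 16, "K1=8 fp16 -> 16-byte loads");
// ... while K1 = 2 issues 4-byte accesses instead.
static_assert(bytes_per_access<std::uint16_t>(2) == 4, "K1=2 fp16 -> 4-byte loads");

int main() { return 0; }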
DeviceOpInstanceKKNN = + 2, + 4, + 4, +- true, ++ false, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, +- true, ++ false, + 1, + 1, + S<1, 64, 1, 2>, +diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +index 4c92c5497..fac19f8b5 100644 +--- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp ++++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial + #define CK_MHA_USE_WAVE_1 + #define CK_MHA_USE_WAVE_2 + #define CK_MHA_USE_WAVE_4 +-#define CK_MHA_USE_WAVE_8 ++//#define CK_MHA_USE_WAVE_8 + using DeviceMHAFactory = + std::tuple< + #ifdef CK_MHA_USE_WAVE_1 +@@ -277,10 +277,10 @@ using DeviceMHAFactory = + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, +- MaskingSpec>, ++ MaskingSpec> + #endif + #ifdef CK_MHA_USE_WAVE_8 +- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ++ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, +diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +index 8e037272b..d463cc871 100644 +--- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp ++++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial + #define CK_MHA_USE_WAVE_1 + #define CK_MHA_USE_WAVE_2 + #define CK_MHA_USE_WAVE_4 +-#define CK_MHA_USE_WAVE_8 ++//#define CK_MHA_USE_WAVE_8 + using DeviceMHAFactory = + std::tuple< + #ifdef CK_MHA_USE_WAVE_1 +@@ -277,10 +277,10 @@ using DeviceMHAFactory = + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, +- MaskingSpec>, ++ MaskingSpec> + #endif + #ifdef CK_MHA_USE_WAVE_8 +- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ++ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, +diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt +index 5465adb77..7534bff3b 100644 +--- a/example/CMakeLists.txt ++++ b/example/CMakeLists.txt +@@ -60,7 +60,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() +@@ -134,7 +134,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) + 
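The attention examples comment out the wave-8 instance (CK_MHA_USE_WAVE_8), and because that instance was the last element of the DeviceMHAFactory tuple, the separating comma moves inside the #ifdef so the list stays well-formed when the macro is off. A minimal sketch of that pattern with made-up kernel types:

// Minimal sketch of the "comma travels with the conditional element" pattern
// used when the last member of a type list is compiled out (illustrative types).
#include <tuple>

struct WaveFourKernel {};
struct WaveEightKernel {};

// #define USE_WAVE_8   // left undefined, as in the patched examples

using KernelFactory = std::tuple<
    WaveFourKernel   // no trailing comma here
#ifdef USE_WAVE_8
    ,                // the comma lives inside the #ifdef with the element it separates
    WaveEightKernel
#endif
    >;

int main() { return std::tuple_size<KernelFactory>::value == 1 ? 0 : 1; }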
endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() +diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp +index 55f562061..69a7abf62 100644 +--- a/include/ck/ck.hpp ++++ b/include/ck/ck.hpp +@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) + #define __gfx11__ + #endif ++#if defined(__gfx1200__) || defined(__gfx1201__) ++#define __gfx12__ ++#endif + + // buffer resource + #ifndef __HIP_DEVICE_COMPILE__ // for host code +@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 + #elif defined(__gfx103__) + #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +-#elif defined(__gfx11__) ++#elif defined(__gfx11__) || defined(__gfx12__) + #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 + #endif + +@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #define CK_USE_AMD_V_FMAC_F32 + #define CK_USE_AMD_V_DOT2_F32_F16 + #define CK_USE_AMD_V_DOT4_I32_I8 +-#elif defined(__gfx11__) ++#elif defined(__gfx11__) || defined(__gfx12__) + #define CK_USE_AMD_V_FMAC_F32 + #define CK_USE_AMD_V_DOT2_F32_F16 + #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 +@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #define CK_USE_AMD_MFMA_GFX940 + #endif + +-// WMMA instruction +-#ifndef __HIP_DEVICE_COMPILE__ // for host code +-#define CK_USE_AMD_WMMA +-#elif defined(__gfx11__) // for GPU code +-#define CK_USE_AMD_WMMA +-#endif +- + // buffer load + #define CK_USE_AMD_BUFFER_LOAD 1 + +diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp +index 116bb3ea0..83af2efe8 100644 +--- a/include/ck/host_utility/device_prop.hpp ++++ b/include/ck/host_utility/device_prop.hpp +@@ -84,4 +84,9 @@ inline bool is_gfx11_supported() + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; + } + ++inline bool is_gfx12_supported() ++{ ++ return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; ++} ++ + } // namespace ck +diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +index f8ee283c6..7eb7d42eb 100644 +--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +@@ -13,6 +13,504 @@ + + namespace ck { + ++#ifdef __gfx12__ ++template ++/* Option: Read from LDS, big buffer hold all threads required data ++ * Source ++ * A: K0PerBlock x MPerBlock x K1 ++ * B: K0PerBlock x NPerBlock x K1 ++ * Destination ++ * C, non-transpose ++ * thread level: MRepeat x NRepeat x MAccVgprs ++ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs ++ * KPACK == WMMA_K = 16 ++ * ++ * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) ++ * Source: ++ * A(if skip LDS): MRepeat x KPack ++ * B(if skip LDS): NRepeat x KPack ++ * Destination ++ * C, non-transpose ++ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs ++ */ ++struct BlockwiseGemmWMMA ++{ ++ static constexpr auto I0 = Number<0>{}; ++ static constexpr auto I1 = 
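ck.hpp gains a __gfx12__ umbrella macro (gfx1200/gfx1201), the gfx11 buffer-resource dword and dot-instruction defines are extended to it, and device_prop.hpp adds a runtime is_gfx12_supported() helper; most device ops later gate on is_gfx11_supported() || is_gfx12_supported(). A self-contained sketch of that compile-time/runtime pairing, with Demo* stand-ins for the ck helpers:

// A sketch (not composable_kernel code) of the two guards this patch adds for
// gfx12: a compile-time umbrella macro for device code, mirroring __gfx12__,
// and a runtime device-name check mirroring is_gfx12_supported().
#include <string>

#if defined(__gfx1200__) || defined(__gfx1201__)
#define DEMO_GFX12 1
#endif

// Stand-in for ck::get_device_name(); hard-wired so the sketch is self-contained.
inline std::string demo_get_device_name() { return "gfx1200"; }

inline bool demo_is_gfx12_supported() {
  const std::string name = demo_get_device_name();
  return name == "gfx1200" || name == "gfx1201";
}

inline bool demo_wmma_supported() {
  // Host-side check in the style of "is_gfx11_supported() || is_gfx12_supported()"
  // used throughout the patched IsSupportedArgument() functions.
  const std::string name = demo_get_device_name();
  const bool gfx11 = name == "gfx1100" || name == "gfx1101" ||
                     name == "gfx1102" || name == "gfx1103";
  return gfx11 || demo_is_gfx12_supported();
}

int main() {
#ifdef DEMO_GFX12
  // Device-side branch would be selected when compiling for gfx1200/gfx1201.
#endif
  return demo_wmma_supported() ? 0 : 1;
}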
Number<1>{}; ++ static constexpr auto I2 = Number<2>{}; ++ static constexpr auto I3 = Number<3>{}; ++ static constexpr auto I4 = Number<4>{}; ++ static constexpr auto I5 = Number<5>{}; ++ static constexpr auto WmmaK = Number<16>{}; ++ ++ using ThisThreadBlock = ThisThreadBlock; ++ ++ // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. ++ static constexpr index_t WaveSize = 32; ++ ++ // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer ++ // When not use LDS, each Row read half of whole data from source buffer, exchange the data via ++ // permutation ++ static constexpr index_t A_KRow = 2; ++ static constexpr index_t B_KRow = 2; ++ ++ static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); ++ static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); ++ ++ static constexpr auto wmma_gemm = ++ WmmaGemm{}; ++ ++ static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); ++ static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); ++ ++ StaticBufferTupleOfVector ++ c_thread_buf_; ++ ++ __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } ++ ++ __device__ static auto GetWaveIdx() ++ { ++ const index_t thread_id = ThisThreadBlock::GetThreadId(); ++ ++ constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( ++ make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), ++ make_tuple(Sequence<0, 1, 2>{}), ++ make_tuple(Sequence<0>{})); ++ ++ return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); ++ } ++ ++ // Default, Block buffer in LDS, thread level offset enabled ++ __device__ static auto CalculateAThreadOriginDataIndex() ++ { ++ if constexpr(AEnableLds) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ const auto waveId_m = wave_idx[I0]; ++ const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); ++ ++ // |KRepeat |MRepeat|MWave |KRow |MLane |KPack ++ return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); ++ } ++ else ++ { ++ return make_tuple(0, 0, 0, 0, 0, 0); ++ } ++ } ++ ++ __device__ static auto CalculateBThreadOriginDataIndex() ++ { ++ if constexpr(BEnableLds) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ const auto waveId_n = wave_idx[I1]; ++ const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); ++ ++ // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack ++ return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); ++ } ++ else ++ { ++ return make_tuple(0, 0, 0, 0, 0, 0); ++ } ++ } ++ ++ template ++ __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ ++ const auto waveId_m = wave_idx[I0]; ++ const auto waveId_n = wave_idx[I1]; ++ ++ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); ++ ++ constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( ++ make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), ++ make_tuple(Sequence<0>{}), ++ make_tuple(Sequence<0, 1, 2>{})); ++ ++ constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( ++ make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), ++ make_tuple(Sequence<0>{}), ++ make_tuple(Sequence<0, 1, 2>{})); ++ ++ const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( ++ make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; ++ const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( ++ 
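The gfx12 BlockwiseGemmWMMA hard-codes WaveSize = 32, treats the K dimension as two K-rows (A_KRow = B_KRow = 2), and derives the wave grid from the block tile, so the thread block must supply exactly MWaves * NWaves * WaveSize lanes. Plugging in the 64x128 tile, 16x16 WMMA, and 2x4 repeats from the example instance earlier in the patch:

// Compile-time check of the wave-tiling arithmetic used by the gfx12
// BlockwiseGemmWMMA (numbers taken from the gemm_wmma_fp16 example above;
// the helper itself is only an illustration).
constexpr int WaveSize  = 32;   // hard-coded for gfx12 wave32 in the patch
constexpr int MPerBlock = 64,  NPerBlock = 128;
constexpr int MPerWmma  = 16,  NPerWmma  = 16;
constexpr int MRepeat   = 2,   NRepeat   = 4;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma);   // 64 / 32  = 2
constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma);   // 128 / 64 = 2

static_assert(MWaves == 2 && NWaves == 2, "2x2 waves per block");
static_assert(MWaves * NWaves * WaveSize == 128,
              "matches the BlockSize of 128 used in the example instance");

int main() { return 0; }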
make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; ++ ++ return make_tuple(c_thread_m, c_thread_n); ++ } ++ ++ template ++ __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ ++ const auto waveId_m = wave_idx[I0]; ++ const auto waveId_n = wave_idx[I1]; ++ ++ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); ++ ++ return make_tuple( ++ Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); ++ } ++ ++ using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); ++ __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), ++ Tuple6 b_origin = CalculateBThreadOriginDataIndex()) ++ : a_thread_copy_(a_origin), b_thread_copy_(b_origin) ++ { ++ static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), ++ "wrong! Desc should be known at compile-time"); ++ ++ static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, ++ "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); ++ ++ static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && ++ NPerBlock % (NPerWMMA * NRepeat) == 0, ++ "wrong!"); ++ } ++ ++ // transposed WMMA output C' = B' * A' ++ __host__ __device__ static constexpr auto ++ GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() ++ { ++ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = ++ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); ++ ++ constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; ++ ++ return make_naive_tensor_descriptor_packed( ++ // |MRepeat |MWave |MSubGroup |NRepeat |NWave ++ // |NThreadPerSubGroup |MAccVgprs ++ make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); ++ } ++ ++ // Thread level, register decriptor. 
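CalculateCThreadOriginDataIndex folds (repeat index, wave index, offset inside the WMMA tile) back into a flat M (or N) coordinate through an unmerge adaptor, with the repeat slowest and the in-tile offset fastest. The same fold as plain arithmetic (illustrative helper, not library code):

// Plain-arithmetic version of the coordinate fold done by
// CalculateCThreadOriginDataIndex via the unmerge adaptor.
constexpr int flat_m(int m0, int wave_m, int row_in_wmma,
                     int MWaves, int MPerWmma) {
  // (repeat, wave, in-tile row) laid out with repeat slowest, row fastest.
  return (m0 * MWaves + wave_m) * MPerWmma + row_in_wmma;
}

static_assert(flat_m(/*m0=*/1, /*wave_m=*/1, /*row=*/5,
                     /*MWaves=*/2, /*MPerWmma=*/16) == 53,
              "(1*2 + 1)*16 + 5 == 53");

int main() { return 0; }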
Vector-write ++ __host__ __device__ static constexpr auto ++ GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() ++ { ++ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = ++ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); ++ ++ constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; ++ constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; ++ return make_naive_tensor_descriptor( ++ // |MRepeat |MWave |MSubGroup |NRepeat |NWave ++ // |NThreadPerSubGroup |MAccVgprs ++ make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), ++ make_tuple(Number{} * MAccVgprs * AccStride, ++ Number{} * MAccVgprs * AccStride, ++ Number{} * MAccVgprs * AccStride, ++ MAccVgprs * AccStride, ++ MAccVgprs * AccStride, ++ MAccVgprs * AccStride, ++ AccStride)); ++ } ++ ++ template ++ __host__ __device__ static constexpr auto ++ MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( ++ const CGridDesc_M_N& c_grid_desc_m_n) ++ { ++ const auto M = c_grid_desc_m_n.GetLength(I0); ++ const auto N = c_grid_desc_m_n.GetLength(I1); ++ ++ const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = ++ transform_tensor_descriptor( ++ c_grid_desc_m_n, ++ make_tuple( ++ make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), ++ make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), ++ make_tuple(Sequence<0>{}, Sequence<1>{}), ++ make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); ++ ++ return wmma_gemm ++ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( ++ c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); ++ } ++ ++ // transposed WMMA output C' = B' * A' ++ __host__ __device__ static constexpr auto ++ GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() ++ { ++ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = ++ make_naive_tensor_descriptor_packed(make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{})); ++ ++ return wmma_gemm ++ .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( ++ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); ++ } ++ ++ // Provide dimension size ++ __host__ __device__ static constexpr auto ++ GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() ++ { ++ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = ++ make_naive_tensor_descriptor_packed(make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{})); ++ ++ return wmma_gemm ++ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( ++ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); ++ } ++ ++ // Describe how data allocated in thread copy src buffer ++ // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma ++ static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; ++ static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; ++ ++ template ++ __device__ void Run(const ABlockBuffer& a_block_buf, ++ const BBlockBuffer& b_block_buf, ++ CThreadBuffer& c_thread_buf) const ++ { ++ auto a_thread_buf = make_static_buffer( ++ a_thread_desc_.GetElementSpaceSize()); ++ auto b_thread_buf = make_static_buffer( ++ b_thread_desc_.GetElementSpaceSize()); ++ ++ 
static_assert(KPack % (A_K1 * A_KRow) == 0, ""); ++ static_assert(KPack % (B_K1 * B_KRow) == 0, ""); ++ ++ // basic intrinsic to determine loopover direction ++ if constexpr(MRepeat < NRepeat) ++ { ++ static_for<0, KPerBlock / KPack, 1>{}( ++ [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... ++ static_for<0, MRepeat, 1>{}([&](auto m0) { ++ // read A ++ a_thread_copy_.Run( ++ a_block_desc_k0_m0_m1_m2_k1, ++ make_tuple(Number{}, m0, I0, I0, I0, I0), ++ a_block_buf, ++ a_thread_desc_, ++ make_tuple(I0, m0, I0, I0, I0, I0), ++ a_thread_buf); ++ ++ static_for<0, NRepeat, 1>{}([&](auto n0) { ++ // read B ++ b_thread_copy_.Run( ++ b_block_desc_k0_n0_n1_n2_k1, ++ make_tuple(Number{}, n0, I0, I0, I0, I0), ++ b_block_buf, ++ b_thread_desc_, ++ make_tuple(I0, n0, I0, I0, I0, I0), ++ b_thread_buf); ++ ++ vector_type a_thread_vec; ++ vector_type b_thread_vec; ++ ++ static_for<0, KPack / A_KRow, 1>{}([&](auto i) { ++ a_thread_vec.template AsType()(i) = ++ a_thread_buf[Number{}]; ++ }); ++ ++ static_for<0, KPack / B_KRow, 1>{}([&](auto i) { ++ b_thread_vec.template AsType()(i) = ++ b_thread_buf[Number{}]; ++ }); ++ ++ using wmma_input_type_a = ++ typename vector_type::type; ++ using wmma_input_type_b = ++ typename vector_type::type; ++ ++ constexpr index_t c_offset = ++ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); ++ ++ wmma_gemm.template Run( ++ a_thread_vec.template AsType(), ++ b_thread_vec.template AsType(), ++ c_thread_buf.GetVectorTypeReference(Number{})); ++ }); ++ }); ++ }); ++ } ++ else ++ { ++ static_for<0, NRepeat, 1>{}([&](auto n0) { ++ static_for<0, MRepeat, 1>{}([&](auto m0) { ++ static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of ++ // k=0,kpack*1, .. ++ // read B ++ b_thread_copy_.Run( ++ b_block_desc_k0_n0_n1_n2_k1, ++ make_tuple(Number{}, n0, I0, I0, I0, I0), ++ b_block_buf, ++ b_thread_desc_, ++ make_tuple(I0, n0, I0, I0, I0, I0), ++ b_thread_buf); ++ // read A ++ a_thread_copy_.Run( ++ a_block_desc_k0_m0_m1_m2_k1, ++ make_tuple(Number{}, m0, I0, I0, I0, I0), ++ a_block_buf, ++ a_thread_desc_, ++ make_tuple(I0, m0, I0, I0, I0, I0), ++ a_thread_buf); ++ ++ vector_type a_thread_vec; ++ vector_type b_thread_vec; ++ ++ static_for<0, KPack / A_KRow, 1>{}([&](auto i) { ++ a_thread_vec.template AsType()(i) = ++ a_thread_buf[Number{}]; ++ }); ++ ++ static_for<0, KPack / B_KRow, 1>{}([&](auto i) { ++ b_thread_vec.template AsType()(i) = ++ b_thread_buf[Number{}]; ++ }); ++ ++ using wmma_input_type_a = ++ typename vector_type::type; ++ using wmma_input_type_b = ++ typename vector_type::type; ++ ++ constexpr index_t c_offset = ++ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); ++ ++ wmma_gemm.template Run( ++ a_thread_vec.template AsType(), ++ b_thread_vec.template AsType(), ++ c_thread_buf.GetVectorTypeReference(Number{})); ++ }); ++ }); ++ }); ++ } ++ } ++ ++ protected: ++ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( ++ make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), ++ make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number<1>{})); ++ ++ static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( ++ make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), ++ make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number<1>{})); ++ ++ // C[M, N, NumRegWMMA] ++ static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( ++ make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); ++ ++ template ++ struct AThreadCopySelector; ++ ++ template <> ++ struct 
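Run() chooses its loop nest from MRepeat < NRepeat: the operand with fewer repeats sits in the outer loop, so each of its fragments is read once and reused across the full sweep of the other operand, and every (m0, n0) pair feeds one wmma_gemm.Run on packed vectors. A scalar reference sketch of that loop-order choice (no WMMA, illustration only):

// Scalar sketch of the loop-order choice in the gfx12 BlockwiseGemmWMMA::Run:
// put the operand with fewer repeats in the outer loop so each of its
// fragments is loaded once and reused. Reference illustration only.
#include <array>

template <int MRepeat, int NRepeat>
void micro_gemm(const std::array<float, MRepeat>& a,
                const std::array<float, NRepeat>& b,
                std::array<float, MRepeat * NRepeat>& c) {
  if constexpr (MRepeat < NRepeat) {
    for (int m = 0; m < MRepeat; ++m)    // load a[m] once...
      for (int n = 0; n < NRepeat; ++n)  // ...sweep every b[n] against it
        c[m * NRepeat + n] += a[m] * b[n];
  } else {
    for (int n = 0; n < NRepeat; ++n)    // load b[n] once...
      for (int m = 0; m < MRepeat; ++m)  // ...sweep every a[m] against it
        c[m * NRepeat + n] += a[m] * b[n];
  }
}

int main() {
  std::array<float, 2> a{1.f, 2.f};
  std::array<float, 4> b{1.f, 1.f, 1.f, 1.f};
  std::array<float, 8> c{};
  micro_gemm<2, 4>(a, b, c);
  return c[7] == 2.f ? 0 : 1;  // c[1*4+3] = a[1]*b[3] = 2
}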
AThreadCopySelector ++ { ++ using type = ++ ThreadwiseTensorSliceTransfer_v4, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ A_K1, ++ A_K1>; ++ }; ++ ++ template <> ++ struct AThreadCopySelector ++ { ++ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< ++ FloatA, ++ FloatA, ++ decltype(a_block_desc_k0_m0_m1_m2_k1), ++ decltype(a_thread_desc_), ++ tensor_operation::element_wise::PassThrough, ++ Sequence, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ A_K1, ++ false>; ++ }; ++ ++ template ++ struct BThreadCopySelector; ++ ++ template <> ++ struct BThreadCopySelector ++ { ++ using type = ++ ThreadwiseTensorSliceTransfer_v4, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ B_K1, ++ B_K1>; ++ }; ++ ++ template <> ++ struct BThreadCopySelector ++ { ++ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< ++ FloatB, ++ FloatB, ++ decltype(b_block_desc_k0_n0_n1_n2_k1), ++ decltype(b_thread_desc_), ++ tensor_operation::element_wise::PassThrough, ++ Sequence, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ B_K1, ++ false>; ++ }; ++ ++ typename AThreadCopySelector::type a_thread_copy_; ++ typename BThreadCopySelector::type b_thread_copy_; ++}; ++#else + template ::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; + }; ++#endif + + } // namespace ck +diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +index e5e6245cb..1f7d50429 100644 +--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp ++++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + // sync point. + if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) + { ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("s_barrier" ::); ++#endif + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +index a15759559..ab3f3856a 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + +- static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; +- static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; ++ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; ++ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; ++ ++ static constexpr auto AEnableLds_auto = ++ (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; ++ static constexpr auto BEnableLds_auto = ++ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? 
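Alongside the WMMA changes, the xdlops interwave pipeline swaps the bare s_barrier for the gfx12 split s_barrier_signal / s_barrier_wait pair. A hedged sketch of an arch-conditional barrier helper in the same spirit (requires a HIP toolchain; demo_block_sync is an illustrative name, not a ck function):

// Sketch of an arch-conditional block barrier mirroring the
// blockwise_gemm_xdlops change; the inline asm matches what the patch emits
// and is only meaningful in HIP device code built for the matching target.
#include <hip/hip_runtime.h>

__device__ inline void demo_block_sync()
{
#if defined(__gfx1200__) || defined(__gfx1201__)
    // gfx12 splits the barrier into an explicit signal/wait pair.
    asm volatile("s_barrier_signal -1\n\t"
                 "s_barrier_wait -1" ::);
#else
    asm volatile("s_barrier" ::);
#endif
}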
false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; +@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + + static bool IsSupportedArgument(const Argument& arg) + { +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + } + else + { +- if(!(arg.a_kz_stride_ == 1 && +- arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) ++ if(!(arg.a_kz_stride_ == 1)) + { +- printf("DeviceOp: Vector Access A-k check failure\n"); +- return false; ++ index_t LastK = ++ AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); ++ if(LastK % ABlockTransferSrcScalarPerVector == 0) ++ { ++ printf("DeviceOp: Vector Access A-k check failure\n"); ++ return false; ++ } + } + } + +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +index 8fd14afc0..1b487502f 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +@@ -70,8 +70,9 @@ __global__ void + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ +- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ ++ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +index 9d5b74be6..017d28641 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +@@ -601,9 +601,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle + return false; + } + +- if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && +- ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && +- std::is_same::value) ++ if(!ck::is_lds_direct_load_supported() && std::is_same::value) + { + return false; + } +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +index b84e18130..1edae33be 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl + { + // check device + if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || +- ck::is_gfx11_supported())) ++ ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return false; + } 
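Several WMMA device ops now auto-decide whether A/B are staged through LDS: the direct-from-VMEM path is taken only when a single wave spans the other dimension and either the load already hits the 16-byte maximum (K1 * sizeof(T) == 16) or there is a single repeat. The heuristic restated as a constexpr helper (illustrative; device_gemm_wmma.hpp additionally requires a PassThrough element-wise op):

// Constexpr restatement of the AEnableLds_auto / BEnableLds_auto heuristic
// added in this patch (illustrative helper, not library code).
template <typename DataT>
constexpr bool enable_lds_auto(int waves_in_other_dim, int repeat, int k1) {
  const bool max_vector_load = k1 * static_cast<int>(sizeof(DataT)) == 16;
  // Skip LDS only when one wave covers the other dimension and the
  // direct-from-VMEM path can still be fed efficiently.
  const bool skip_lds = (waves_in_other_dim == 1) && (max_vector_load || repeat == 1);
  return !skip_lds;
}

// fp16, K1 = 8 (16-byte loads), one wave in the other dimension: LDS bypassed.
static_assert(!enable_lds_auto<unsigned short>(1, 2, 8), "direct load, no LDS");
// fp16, K1 = 2: vector load too narrow and repeat > 1, so LDS stays on.
static_assert(enable_lds_auto<unsigned short>(1, 2, 2), "stage through LDS");

int main() { return 0; }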
+diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +index bf96324d0..553143e28 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB || is_same_v || + is_same_v)) +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +index b1784b385..eb0fb55f5 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +index 93ab8a7e1..a7cc546f5 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; + +- static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); +- static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); +- static constexpr auto WmmaK = K1 == 16 ? 32 : 16; +- +- static constexpr auto AEnableLds_auto = +- (NWaves == 1 && is_same::value) ? false : true; ++ static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); ++ static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); ++ static constexpr auto WmmaK = K1 == 16 ? 32 : 16; ++ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; ++ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; ++ ++ static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && ++ is_same::value) ++ ? false ++ : true; + static constexpr auto BEnableLds_auto = +- (MWaves == 1 && is_same::value) ? false : true; ++ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && ++ is_same::value) ++ ? 
false ++ : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; +@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || + is_same_v)) +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +index 6f74838fb..6bb5d431c 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + static bool IsSupportedArgument(const Argument& arg) + { + // check device +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +index bd264a3c8..7047e1bda 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +@@ -48,8 +48,9 @@ __global__ void + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ +- defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ ++ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +index 211185dfb..5738be0fb 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle + static bool IsSupportedArgument(const Argument& arg) + { + // check device +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +index 7cfbd8a8f..5d5a9de7d 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +@@ -90,8 +90,9 @@ __global__ void + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ +- defined(__gfx90a__) 
|| defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ ++ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + + // check device + if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || +- ck::is_gfx103_supported() || ck::is_gfx11_supported())) ++ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return false; + } +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +index 6a4d97d7d..c65370b51 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +@@ -107,7 +107,7 @@ __global__ void + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { + #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ +- defined(__gfx11__)) ++ defined(__gfx11__) || defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +index ac392cddc..060a16d1e 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +@@ -39,8 +39,9 @@ __global__ void + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ +- defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ ++ defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ ++ defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); +@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +index 4e14ed3a5..cc88c1a10 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +@@ -60,7 +60,7 @@ __global__ void + bool input_permute, + bool output_permute) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off + // 
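Each kernel entry point guards its body with #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx...__)) so that host compilation and supported device targets get the real body while other targets compile an empty stub; the patch simply appends __gfx12__ to those lists. A minimal guarded-kernel sketch (hypothetical kernel; the real code uses ck's __gfx11__/__gfx12__ umbrella macros and 'ignore =' assignments in the stub branch):

// Minimal sketch of the guard pattern that the patch extends with gfx12
// targets (HIP toolchain assumed; demo_kernel is illustrative only).
#include <hip/hip_runtime.h>

__global__ void demo_kernel(float* out, const float* in, int n)
{
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || \
     defined(__gfx1200__) || defined(__gfx1201__))
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i];
#else
    // Unsupported target: compile an empty stub so the TU still builds.
    (void)out; (void)in; (void)n;
#endif
}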
*************************************************** +@@ -165,6 +165,7 @@ __global__ void + ignore = O; + ignore = G0; + ignore = G1; ++ ignore = alpha; + ignore = input_permute; + ignore = output_permute; + #endif // end of if (defined(__gfx11__)) +@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma + + static bool IsSupportedArgument(const RawArg& arg) + { +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +index 16717ff81..1754e07e6 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma + if constexpr(B0EnableLds) + { + // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 +- constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + B0BlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma + if constexpr(B1EnableLds) + { + // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 +- constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); +- constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); ++ constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); ++ constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_LRow = I2; ++#else + constexpr auto B_LRow = I1; ++#endif + return transform_tensor_descriptor( + B1BlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +index 499eb7eb0..21dac6f9e 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +@@ -50,7 +50,7 @@ __global__ void + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, +@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 +- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); +- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); ++ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); 
++#ifdef __gfx12__ ++ constexpr auto A_KRow = I2; ++#else + constexpr auto A_KRow = I1; ++#endif + return transform_tensor_descriptor( + ABlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 +- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + BBlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +index 82d010a99..fdda649ef 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +@@ -54,7 +54,7 @@ __global__ void + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -147,7 +147,7 @@ __global__ void + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_etile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + // printf("entry kernel launch"); + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; + +@@ -237,7 +237,7 @@ __global__ void + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; + + GridwiseOp::template Run(p_a_grid, +@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma + } + else + { ++ constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma + } + else + { ++ constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -495,12 +497,16 @@ struct 
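Under __gfx12__ the gridwise WMMA descriptors carry the K dimension as two K-rows (A_KRow = B_KRow = I2 instead of I1), and the register-resident path derives K0PerWmma as WmmaK / KRow / K1 rather than the previously hard-coded WmmaK / 2 / K1. The factorization in plain numbers (illustrative values):

// Illustrative arithmetic for the A_KRow / B_KRow switch in the gridwise WMMA
// descriptors: one WMMA K-slice is viewed as (K0PerWmma, KRow, K1).
constexpr int WmmaK = 16;
constexpr int K1    = 2;
constexpr int KRow  = 2;  // gfx12 path; the gfx11 LDS path keeps KRow = 1

constexpr int K0PerWmma = WmmaK / KRow / K1;  // 16 / 2 / 2 = 4

static_assert(K0PerWmma * KRow * K1 == WmmaK,
              "K0PerWmma, KRow and K1 together tile one WMMA K-slice");

int main() { return 0; }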
GridwiseGemmMultipleD_Wmma + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 +- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); +- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); ++ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto A_KRow = I2; ++#else + constexpr auto A_KRow = I1; ++#endif + return transform_tensor_descriptor( + ABlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 +- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + BBlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { +- constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); +- constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); +- + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, +- Number{}, ++ Number{}, + I1, +- Number{})); ++ Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } +@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + +- const auto MBlock = M / MPerBlock; +- const auto NBlock = N / NPerBlock; ++ const auto MBlock = M / MPerBlock; ++ const auto NBlock = N / NPerBlock; ++ + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +index 8e4117593..4458b9356 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +@@ -45,7 +45,7 @@ __global__ void + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, +@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma + } + else + { ++ constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 
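The output (E/C) grid descriptor splits M and N into block and in-block coordinates, MBlock = M / MPerBlock and NBlock = N / NPerBlock, before the unmerge transform builds the MBlock_MPerBlock_NBlock_NPerBlock view. The same decomposition as plain index math (illustrative helper, not library code):

// Sketch of the (MBlock, MPerBlock, NBlock, NPerBlock) decomposition used for
// the output grid descriptor.
struct BlockCoord { int m_block, m_in_block, n_block, n_in_block; };

constexpr BlockCoord decompose(int m, int n, int MPerBlock, int NPerBlock) {
  return { m / MPerBlock, m % MPerBlock, n / NPerBlock, n % NPerBlock };
}

static_assert(decompose(130, 70, 64, 128).m_block == 2, "row 130 lives in M-block 2");
static_assert(decompose(130, 70, 64, 128).n_in_block == 70, "column offset inside N-block 0");

int main() { return 0; }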
Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma + } + else + { ++ ++ constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 +- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); +- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); ++ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto A_KRow = I2; ++#else + constexpr auto A_KRow = I1; ++#endif ++ + return transform_tensor_descriptor( + ABlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 +- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + BBlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma + c_grid_desc_m_n); + } + +- using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = +- remove_cvref_t; +- using DefaultBlock2CTileMap = +- remove_cvref_t; +- + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment +@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma + b_block_space_size_aligned * sizeof(BDataType)); + }; + ++ using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = ++ remove_cvref_t; ++ using DefaultBlock2CTileMap = ++ remove_cvref_t; ++ + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +index 6772524e0..174074990 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +@@ -35,8 +35,9 @@ __global__ void + const Block2ETileMap block_2_tile_map, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ +- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ ++ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + GridwiseTensorRearrangeKernel::Run(in_grid_desc, + p_in_global, + 
out_grid_desc, +diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +index bcce930fc..d7a6a3624 100644 +--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp ++++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic + ElementwiseOperation element_op_; + }; + +-// Specilized for WMMA ++// Specilized for WMMA-Navi3 + // A single Wave32 is composed by double row + // Data exchange allowed between these two rows + // This RowLane Dst buf will be filled from two Src buf +@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow + ElementwiseOperation element_op_{}; + }; + ++// Specilized for WMMA-Navi4 ++template ::type = false> ++struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow ++{ ++ static constexpr index_t nDim = SliceLengths::Size(); ++ ++ using Index = MultiIndex; ++ ++ __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) ++ { ++ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), ++ "wrong! Desc need to known at compile-time"); ++ ++ static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, ++ "wrong! Not divisible"); ++ ignore = src_idx; ++ } ++ ++ template ++ __device__ void Run(const SrcDesc&, ++ const SrcSliceOriginIdx&, ++ const SrcBuffer& src_buf, ++ const DstDesc&, ++ const DstSliceOriginIdx&, ++ DstBuffer& dst_buf) const ++ { ++ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), ++ "wrong! Desc need to known at compile-time"); ++ ++ static_assert(is_known_at_compile_time>::value && ++ is_known_at_compile_time>::value, ++ "wrong! SliceOrigin need to known at compile-time"); ++ ++ static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), ++ "wrong! 
Buffer need to be StaticBuffer"); ++ ++ // SrcDesc and src_slice_origin_idx are known at compile-time ++ constexpr auto src_desc = remove_cvref_t{}; ++ constexpr auto dst_desc = remove_cvref_t{}; ++ constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); ++ constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); ++ ++ // scalar per access on each dim ++ constexpr auto dst_scalar_per_access = generate_sequence( ++ detail::lambda_scalar_per_access{}, Number{}); ++ ++ constexpr auto dst_scalar_step_in_vector = ++ generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); ++ ++ using SpaceFillingCurve = SpaceFillingCurve>; ++ ++ static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, ++ "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); ++ ++ constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); ++ ++ static_for<0, num_access, 1>{}([&](auto idx_1d) { ++ constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); ++ ++ // copy data from src_buf into dst_vector ++ static_for<0, DstScalarPerVector, 1>{}([&](auto i) { ++ // src_desc error, non constexpr, caused by merge transform ++ constexpr index_t src_offset = src_desc.CalculateOffset( ++ src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); ++ ++ constexpr index_t dst_offset = dst_desc.CalculateOffset( ++ dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); ++ ++ SrcData v_this_row; ++ // int type temp value due to intrinsic requirement ++ int temp = 0; ++ ++ // apply element-wise operation ++ element_op_(v_this_row, src_buf[Number{}]); ++ ++ // apply intra-row permute. ++ if constexpr(IntraRowSwizzlePerm) ++ { ++ temp = __builtin_amdgcn_permlane16( ++ temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); ++ v_this_row = type_convert_sp(temp); ++ } ++ ++ // apply type convert ++ dst_buf(Number{}) = type_convert_sp(v_this_row); ++ }); ++ }); ++ } ++ ElementwiseOperation element_op_{}; ++}; ++ + } // namespace ck +diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +index 565195f53..9a9ebf559 100644 +--- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp ++++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +@@ -11,12 +11,17 @@ namespace ck { + + enum struct WmmaInstr + { ++ // gfx11 + wmma_f32_16x16x16_f16 = 0, + wmma_f32_16x16x16_bf16, + wmma_f16_16x16x16_f16, + wmma_bf16_16x16x16_bf16, + wmma_i32_16x16x16_iu8, +- wmma_i32_16x16x16_iu4 ++ wmma_i32_16x16x16_iu4, ++ // gfx12 ++ wmma_f32_16x16x16_f16_gfx12, ++ wmma_f32_16x16x16_bf16_gfx12, ++ wmma_i32_16x16x16_iu8_gfx12, + }; + + /* +@@ -279,6 +284,122 @@ struct wmma_type ++struct wmma_type> ++{ ++ // Absolute fixing property ++ // * Data Pixel ++ static constexpr index_t m_per_wmma = 16; ++ static constexpr index_t n_per_wmma = 16; ++ static constexpr index_t k_per_wmma = 16; ++ // static constexpr index_t src_a_data_size = 2; ++ // static constexpr index_t src_b_data_size = 2; ++ // static constexpr index_t acc_data_size = 4; ++ // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction ++ static constexpr index_t acc_data_size = 4; ++ static constexpr index_t acc_pack_number = 1; ++ static constexpr index_t num_thread_per_subgroups = n_per_wmma; ++ ++ // Wave mode dependent propety ++ static constexpr index_t wave_size = Number{}; ++ // * Fixed in Navi3x, Will be wave mode dependent on Navi4x ++ // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * 
src_a_data_size / 4; ++ // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; ++ // * num_acc_vgprs_per_wave alone M direction ++ // * num_subgroups alone M direction ++ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; ++ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; ++ ++ template ++ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const ++ { ++ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); ++ if constexpr(wave_size == 32) ++ { ++ intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); ++ } ++ } ++}; ++ ++template ++struct wmma_type> ++{ ++ // Absolute fixing property ++ static constexpr index_t m_per_wmma = 16; ++ static constexpr index_t n_per_wmma = 16; ++ static constexpr index_t k_per_wmma = 16; ++ // static constexpr index_t src_a_data_size = 2; ++ // static constexpr index_t src_b_data_size = 2; ++ static constexpr index_t acc_data_size = 4; ++ static constexpr index_t acc_pack_number = 1; ++ static constexpr index_t num_thread_per_subgroups = n_per_wmma; ++ ++ // Wave mode dependent propety ++ static constexpr index_t wave_size = Number{}; ++ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; ++ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; ++ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; ++ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; ++ ++ template ++ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const ++ { ++ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); ++ if constexpr(wave_size == 32) ++ { ++ intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); ++ } ++ } ++}; ++ ++template ++struct wmma_type> ++{ ++ // Absolute fixing property ++ static constexpr index_t m_per_wmma = 16; ++ static constexpr index_t n_per_wmma = 16; ++ static constexpr index_t k_per_wmma = 16; ++ // static constexpr index_t src_a_data_size = 2; ++ // static constexpr index_t src_b_data_size = 2; ++ static constexpr index_t acc_data_size = 4; ++ static constexpr index_t acc_pack_number = 1; ++ static constexpr index_t num_thread_per_subgroups = n_per_wmma; ++ ++ // Wave mode dependent propety ++ static constexpr index_t wave_size = Number{}; ++ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; ++ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; ++ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; ++ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; ++ ++ template ++ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const ++ { ++ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); ++ if constexpr(wave_size == 32) ++ { ++ intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( ++ a, b, reg_c); ++ } ++ } ++}; ++ + template + static constexpr auto GetWmma() + { ++#ifdef __gfx12__ ++ return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; ++#else + return WmmaInstr::wmma_f32_16x16x16_f16; ++#endif + } + + template <> + static constexpr auto GetWmma() + { ++#ifdef __gfx12__ ++ return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; ++#else + return WmmaInstr::wmma_f32_16x16x16_bf16; ++#endif + } + + template <> +@@ -320,8 +449,13 @@ struct WmmaSelector + template <> + static constexpr auto GetWmma() + { 
++#ifdef __gfx12__ ++ return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; ++#else + return WmmaInstr::wmma_i32_16x16x16_iu8; ++#endif + } ++ + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + static constexpr auto GetWmma() +@@ -502,6 +636,9 @@ struct WmmaGemm + + __device__ static auto GetSubGroupId() + { ++ static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == ++ wmma_instr.wave_size, ++ ""); + return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; + } + +@@ -516,12 +653,20 @@ struct WmmaGemm + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { ++#ifdef __gfx12__ ++ return GetLaneIdUnderSubGroup(); ++#else + return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); ++#endif + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { ++#ifdef __gfx12__ ++ return GetLaneIdUnderSubGroup(); ++#else + return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); ++#endif + } + + __device__ static CIndex GetBeginOfThreadBlk() +diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp +index 1bb0140f3..322a0f94b 100644 +--- a/include/ck/utility/amd_wmma.hpp ++++ b/include/ck/utility/amd_wmma.hpp +@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> + } + }; + ++// gfx12 ++/********************************WAVE32 MODE***********************************************/ ++ ++#if defined(__gfx1200__) || defined(__gfx1201__) ++#define __gfx12__ ++#endif ++ ++// src: fp16, dst: fp32 ++template ++struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; ++ ++template <> ++struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> ++{ ++ template ++ __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) ++ { ++ // * Inline assembly need to elimate the duplicated data load, compiler won't help you ++ // delete them. 
++ // amd_assembly_wmma_f32_16x16x16_f16_w32( ++ // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); ++#if defined(__gfx12__) ++ reg_c.template AsType()(Number<0>{}) = ++ __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( ++ reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); ++#else ++ ignore = reg_a; ++ ignore = reg_b; ++ ignore = reg_c; ++#endif ++ } ++}; ++ ++// src: bf16, dst: fp32 ++template ++struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; ++ ++template <> ++struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> ++{ ++ template ++ __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) ++ { ++#if defined(__gfx12__) ++ reg_c.template AsType()(Number<0>{}) = ++ __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( ++ reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); ++#else ++ ignore = reg_a; ++ ignore = reg_b; ++ ignore = reg_c; ++#endif ++ } ++}; ++ ++// src: iu8, dst: i32 ++template ++struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; ++ ++template ++struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> ++{ ++ template ++ __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) ++ { ++#if defined(__gfx12__) ++ reg_c.template AsType()(Number<0>{}) = ++ __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( ++ neg_a, ++ bit_cast(reg_a), ++ neg_b, ++ bit_cast(reg_b), ++ reg_c.template AsType()[Number<0>{}], ++ clamp); ++#else ++ ignore = reg_a; ++ ignore = reg_b; ++ ignore = reg_c; ++#endif ++ } ++}; ++ + } // namespace ck + #endif +diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp +index 93a1edefb..4df14c621 100644 +--- a/include/ck/utility/data_type.hpp ++++ b/include/ck/utility/data_type.hpp +@@ -203,7 +203,7 @@ struct vector_type + } + }; + +-int static err = 0; ++__device__ int static err = 0; + template + struct vector_type + { +diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp +index 4fe5e3950..d6b6eac26 100644 +--- a/include/ck/utility/synchronization.hpp ++++ b/include/ck/utility/synchronization.hpp +@@ -10,12 +10,20 @@ namespace ck { + __device__ void block_sync_lds() + { + #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); ++#endif + #else + __syncthreads(); + #endif +@@ -23,11 +31,20 @@ __device__ void block_sync_lds() + + __device__ void block_sync_lds_direct_load() + { ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_vmcnt 0x0 \n \ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("\ + s_waitcnt vmcnt(0) \n \ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); ++#endif + } + + __device__ void s_nop() +diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp +index 601aad19b..9dc2b072a 100644 +--- a/include/ck_tile/core/config.hpp ++++ b/include/ck_tile/core/config.hpp +@@ -17,6 +17,9 @@ + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) + #define __gfx11__ + #endif ++#if defined(__gfx1200__) || defined(__gfx1201__) ++#define __gfx12__ ++#endif + + #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS + #include "hip/hip_runtime.h" +@@ -155,7 +158,7 @@ + #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 + #elif 
defined(__gfx103__) // for GPU code + #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +-#elif defined(__gfx11__) // for GPU code ++#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code + #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 + #endif + +diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt +index 8c5f36d2e..89c9d6dc6 100644 +--- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt ++++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt +@@ -52,7 +52,7 @@ function(add_instance_library INSTANCE_NAME) + endforeach() + # Do not build WMMA instances if gfx11 targets are not on the target list + foreach(source IN LISTS ARGN) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() +@@ -149,7 +149,7 @@ FOREACH(subdir_path ${dir_list}) + message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") + set(add_inst 0) + endif() +- if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) ++ if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12")) + message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") + set(add_inst 0) + endif() +@@ -157,11 +157,11 @@ FOREACH(subdir_path ${dir_list}) + message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") + set(add_inst 0) + endif() +- if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) ++ if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") + set(add_inst 0) + endif() +- if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) ++ if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + message("Found xdl, dl, and wmma instances, but none of those meet the target list. 
Skipping.") + set(add_inst 0) + endif() +diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt +index 1cfcbfff6..a9557a9b9 100644 +--- a/profiler/src/CMakeLists.txt ++++ b/profiler/src/CMakeLists.txt +@@ -58,7 +58,7 @@ if(GPU_TARGETS MATCHES "gfx9") + + endif() + +-if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") ++if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + endif() +@@ -133,7 +133,7 @@ if(GPU_TARGETS MATCHES "gfx9") + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + endif() + +-if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") ++if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) + endif() +diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt +index 25c63ac7f..2a7c52b58 100644 +--- a/test/CMakeLists.txt ++++ b/test/CMakeLists.txt +@@ -53,7 +53,7 @@ function(add_test_executable TEST_NAME) + endif() + endforeach() + foreach(source IN LISTS ARGN) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() +@@ -118,7 +118,7 @@ function(add_gtest_executable TEST_NAME) + endif() + endforeach() + foreach(source IN LISTS ARGN) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() +diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +index 1c8082645..21f49ec0f 100644 +--- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp ++++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +@@ -55,7 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test + } + } + +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + // on gfx11 only support for 3d is implemented + if constexpr(NDimSpatial{} != 3) +diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp +index 49782bce6..d9ec94771 100644 +--- a/test/wmma_op/wmma_op_util.hpp ++++ b/test/wmma_op/wmma_op_util.hpp +@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) + p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; + } + ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); ++#endif + + for(int ele = 0; ele < 16; ++ele) + { +@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) + a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; + } + ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); ++#endif + + // sync threads, similar to mma_sync + // __syncthreads(); 
diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch deleted file mode 100644 index d696d386452e8..0000000000000 --- a/cmake/patches/dawn/dawn.patch +++ /dev/null @@ -1,66 +0,0 @@ -diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt -index 9c0bd6fa4e..bf8a57aeac 100644 ---- a/src/dawn/native/CMakeLists.txt -+++ b/src/dawn/native/CMakeLists.txt -@@ -857,6 +857,11 @@ if (DAWN_ENABLE_SWIFTSHADER) - target_compile_definitions(dawn_native PRIVATE "DAWN_ENABLE_SWIFTSHADER") - endif() - -+if (IOS) -+ target_compile_options(dawn_native_objects PRIVATE -fno-objc-arc) -+ target_compile_options(dawn_native PRIVATE -fno-objc-arc) -+endif() -+ - if (DAWN_BUILD_MONOLITHIC_LIBRARY) - ############################################################################### - # Do the 'complete_lib' build. -diff --git a/src/dawn/native/Surface_metal.mm b/src/dawn/native/Surface_metal.mm -index ce55acbd43..baa4835362 100644 ---- a/src/dawn/native/Surface_metal.mm -+++ b/src/dawn/native/Surface_metal.mm -@@ -36,7 +36,13 @@ - namespace dawn::native { - - bool InheritsFromCAMetalLayer(void* obj) { -- id object = static_cast(obj); -+ id object = -+#if TARGET_OS_IOS -+ (__bridge id)obj; -+#else -+ static_cast(obj); -+#endif -+ - return [object isKindOfClass:[CAMetalLayer class]]; - } - -diff --git a/src/dawn/native/metal/SharedFenceMTL.mm b/src/dawn/native/metal/SharedFenceMTL.mm -index bde8bfea07..f2f6459e91 100644 ---- a/src/dawn/native/metal/SharedFenceMTL.mm -+++ b/src/dawn/native/metal/SharedFenceMTL.mm -@@ -40,7 +40,13 @@ ResultOrError> SharedFence::Create( - DAWN_INVALID_IF(descriptor->sharedEvent == nullptr, "MTLSharedEvent is missing."); - if (@available(macOS 10.14, iOS 12.0, *)) { - return AcquireRef(new SharedFence( -- device, label, static_cast>(descriptor->sharedEvent))); -+ device, label, -+#if TARGET_OS_IOS -+ (__bridge id)(descriptor->sharedEvent) -+#else -+ static_cast>(descriptor->sharedEvent) -+#endif -+ )); - } else { - return DAWN_INTERNAL_ERROR("MTLSharedEvent not supported."); - } -diff --git a/src/tint/api/BUILD.cmake b/src/tint/api/BUILD.cmake -index 0037d83276..6372c4ee77 100644 ---- a/src/tint/api/BUILD.cmake -+++ b/src/tint/api/BUILD.cmake -@@ -57,6 +57,7 @@ tint_target_add_dependencies(tint_api lib - tint_lang_wgsl_ast_transform - tint_lang_wgsl_common - tint_lang_wgsl_features -+ tint_lang_wgsl_inspector - tint_lang_wgsl_program - tint_lang_wgsl_sem - tint_lang_wgsl_writer_ir_to_program diff --git a/cmake/patches/eigen/eigen-edge.patch b/cmake/patches/eigen/eigen-edge.patch new file mode 100644 index 0000000000000..d8dc850b4bd55 --- /dev/null +++ b/cmake/patches/eigen/eigen-edge.patch @@ -0,0 +1,13 @@ +diff --git a/Eigen/src/Core/util/IndexedViewHelper.h b/Eigen/src/Core/util/IndexedViewHelper.h +index f85de305f..3dc2bb5e7 100644 +--- a/Eigen/src/Core/util/IndexedViewHelper.h ++++ b/Eigen/src/Core/util/IndexedViewHelper.h +@@ -178,7 +178,7 @@ namespace placeholders { + + EIGEN_DEPRECATED static const all_t all = Eigen::all; // PLEASE use Eigen::all instead of Eigen::placeholders::all + EIGEN_DEPRECATED static const last_t last = Eigen::last; // PLEASE use Eigen::last instead of Eigen::placeholders::last +- EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end ++ // EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end + } + + } // end namespace Eigen diff --git a/cmake/patches/onnx/onnx.patch 
b/cmake/patches/onnx/onnx.patch index 162d33581a5ca..58697e293e583 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -86,3 +86,943 @@ index 0aab3e26..398ac2d6 100644 +#endif + #endif // ! ONNX_ONNX_PB_H +diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc +index c315a2a7..58963154 100644 +--- a/onnx/defs/math/defs.cc ++++ b/onnx/defs/math/defs.cc +@@ -3472,6 +3472,9 @@ ONNX_OPERATOR_SET_SCHEMA( + } + + auto& input_shape = getInputShape(ctx, 0); ++ if (input_shape.dim_size() < 2) { ++ fail_shape_inference("First input should have at least 2 dimensions in ", ctx.getDisplayName(), "."); ++ } + auto signal_dim = input_shape.dim(1); + if (!signal_dim.has_dim_value()) { + return; +diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc +index be6a851d..fad595d0 100644 +--- a/onnx/defs/nn/defs.cc ++++ b/onnx/defs/nn/defs.cc +@@ -126,6 +126,9 @@ void convPoolShapeInference( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +@@ -959,19 +962,21 @@ ONNX_OPERATOR_SET_SCHEMA( + auto w_type = ctx.getInputType(3); + if (nullptr == x_type || nullptr == w_type || x_type->value_case() != TypeProto::kTensorType || + w_type->value_case() != TypeProto::kTensorType) { +- fail_type_inference("inputs are expected to have tensor type."); ++ fail_type_inference("inputs are expected to have tensor type in ", ctx.getDisplayName(), "."); + } + + auto x_zero_point_type = ctx.getInputType(2); + if (nullptr == x_zero_point_type || + x_zero_point_type->tensor_type().elem_type() != x_type->tensor_type().elem_type()) { +- fail_type_inference("input and zero_point pair is expected to have be same type."); ++ fail_type_inference( ++ "input and zero_point pair is expected to have be same type in ", ctx.getDisplayName(), "."); + } + + auto w_zero_point_type = ctx.getInputType(5); + if (nullptr == w_zero_point_type || + w_zero_point_type->tensor_type().elem_type() != w_type->tensor_type().elem_type()) { +- fail_type_inference("weight and zero_point pair is expected to have same type."); ++ fail_type_inference( ++ "weight and zero_point pair is expected to have same type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromInputToOutput(ctx, 7, 0); +@@ -2647,7 +2652,8 @@ ONNX_OPERATOR_SET_SCHEMA( + if (!hasNInputShapes(ctx, 1)) { + return; + } +- auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); ++ ++ auto& input_shape = getInputShape(ctx, 0); + int64_t input_ndim = input_shape.dim_size(); + int64_t axis = -1; + auto axis_proto = ctx.getAttribute("axis"); +@@ -2659,7 +2665,16 @@ ONNX_OPERATOR_SET_SCHEMA( + // positive value. 
+ axis += input_ndim; + } +- ++ if (axis < 0) { ++ fail_shape_inference( ++ "Unexpected axis value (", ++ axis, ++ ") rank of first input is ", ++ input_ndim, ++ " in ", ++ ctx.getDisplayName(), ++ "."); ++ } + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); +diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc +index 57f8e2a4..8b2dc07f 100644 +--- a/onnx/defs/nn/old.cc ++++ b/onnx/defs/nn/old.cc +@@ -201,6 +201,9 @@ void convPoolShapeInference_opset19( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +diff --git a/onnx/defs/shape_inference.h b/onnx/defs/shape_inference.h +index a80473b3..d1bcd401 100644 +--- a/onnx/defs/shape_inference.h ++++ b/onnx/defs/shape_inference.h +@@ -105,6 +105,10 @@ struct InferenceContext { + virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0; + // Gets the shape inputs computed by partial data propagation. + virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0; ++ // To display a name the user can use to narrow its search. ++ virtual std::string getDisplayName() const { ++ return ""; ++ } + }; + + // We use data propagation to perform partial evaluation of the model, to compute statically +@@ -263,7 +267,15 @@ inline void propagateElemTypeFromDtypeToOutput( + } else { + // This is not expected to happen + fail_type_inference( +- "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case); ++ "Output ", ++ outputIndex, ++ " expected to have: ", ++ expected_value_case, ++ " or UNDEFINED. 
Got: ", ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -277,18 +289,18 @@ inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const Attr + const auto attr_type = attr->type(); + if (attr_type == AttributeProto::TENSOR) { + if (attr->t().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim tensor"); ++ fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->t().data_type(); + expected_value_case = TypeProto::kTensorType; + } else if (attr_type == AttributeProto::SPARSE_TENSOR) { + if (attr->sparse_tensor().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim sparse tensor"); ++ fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->sparse_tensor().values().data_type(); + expected_value_case = TypeProto::kSparseTensorType; + } else { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case); +@@ -326,7 +338,10 @@ inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t + const auto* input_type = ctx.getInputType(n); + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); ++ } ++ if (!hasShape(*input_type)) { ++ fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return input_type->tensor_type().shape(); +@@ -344,7 +359,7 @@ inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size + + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return &input_type->tensor_type().shape(); +@@ -372,7 +387,10 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + " does not match type of output: ", + outputIndex, + "type: ", +- output_value_case); ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + if (TypeProto::kTensorType == input_value_case) { + auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim(); +@@ -382,7 +400,13 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + *dim = input_type->sparse_tensor_type().shape().dim(static_cast(fromDimIndex)); + } else { + fail_type_inference( +- "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type"); ++ "Input ", ++ inputIndex, ++ " and Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -440,7 +464,14 @@ updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType + setTensorElementType(elemType, expected_type, 
*output_type); + } else { + // This is not expected to happen +- fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type); ++ fail_type_inference( ++ "Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type: ", ++ expected_type, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -462,16 +493,17 @@ inline void propagateElemTypeFromAttributeToOutput( + updateOutputElemType(ctx, outputIndex, default_value, expected_type); + return; + } else { +- fail_type_inference("Value of attribute ", attributeName, " not specified"); ++ fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), "."); + } + } + if (!attr_proto->has_i()) { +- fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type."); ++ fail_type_inference( ++ "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), "."); + } + auto attr_value = attr_proto->i(); + auto elem_type = static_cast(attr_value); + if (!TensorProto_DataType_IsValid(elem_type)) { +- fail_type_inference("Attribute ", attributeName, " does not specify a valid type."); ++ fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), "."); + } + updateOutputElemType(ctx, outputIndex, elem_type, expected_type); + } +@@ -497,7 +529,7 @@ inline TensorShapeProto* + getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) { + auto output_type = ctx.getOutputType(n); + if (output_type == nullptr) { +- fail_type_inference("Output ", n, " expected to have tensor or sparse type"); ++ fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), "."); + } + const auto output_value_case = output_type->value_case(); + if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) { +@@ -505,7 +537,7 @@ getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_typ + } else if (output_value_case == TypeProto::VALUE_NOT_SET) { + return getTensorMutableShape(default_type, *output_type); + } else { +- fail_type_inference("Output ", n, " expected to have tensor type"); ++ fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), "."); + } + } + +@@ -562,13 +594,13 @@ inline void propagateShapeFromAttributeToOutput( + auto attr_proto = ctx.getAttribute(attributeName); + if ((nullptr == attr_proto) || (!attr_proto->has_type()) || + (attr_proto->type() != AttributeProto_AttributeType_INTS)) { +- fail_shape_inference("Attribute ", attributeName, " should specify a shape"); ++ fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), "."); + } + auto& int_list = attr_proto->ints(); + TensorShapeProto shape; + for (auto dim_size : int_list) { + if (dim_size < 0) { +- fail_shape_inference("Negative values are not allowed in a shape specification"); ++ fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), "."); + } + shape.add_dim()->set_dim_value(dim_size); + } +@@ -745,7 +777,16 @@ inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expect + if (hasInputShape(ctx, input_index)) { + auto rank = getInputShape(ctx, input_index).dim_size(); + if (rank != expected_rank) { +- fail_shape_inference("Input ", input_index, " expected to have rank 
", expected_rank, " but has rank ", rank); ++ fail_shape_inference( ++ "Input ", ++ input_index, ++ " expected to have rank ", ++ expected_rank, ++ " but has rank ", ++ rank, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + } +@@ -798,7 +839,15 @@ inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_ind + // This shape is expected to have rank > dim_index: + if (input_shape.dim_size() <= dim_index) { + fail_shape_inference( +- "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size()); ++ "Input ", ++ input_index, ++ " expected to have rank >", ++ dim_index, ++ " but has rank ", ++ input_shape.dim_size(), ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + const Dim& input_dim = input_shape.dim(dim_index); + // Now, unify dim and input_dim: +diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc +index 8723dcd4..8249fc59 100644 +--- a/onnx/shape_inference/implementation.cc ++++ b/onnx/shape_inference/implementation.cc +@@ -906,7 +906,7 @@ struct FunctionInferenceContext : public InferenceContext { + const std::vector& input_types, + const std::vector& attributes, + const ShapeInferenceOptions& options) +- : input_types_(input_types), options_(options) { ++ : input_types_(input_types), options_(options), func_proto_(&func_proto) { + for (const auto& attr : attributes) { + attributesByName_[attr.name()] = &attr; + } +@@ -971,11 +971,25 @@ struct FunctionInferenceContext : public InferenceContext { + return std::move(output_types_); + } + ++ std::string getDisplayName() const override { ++ if (func_proto_ == nullptr) ++ return ""; ++ if (func_proto_->domain().empty()) { ++ if (func_proto_->name().empty()) ++ return ""; ++ return MakeString("function ", func_proto_->name()); ++ } ++ if (func_proto_->name().empty()) ++ return MakeString("function [", func_proto_->domain(), "]"); ++ return MakeString("function ", func_proto_->name(), "[", func_proto_->domain(), "]"); ++ } ++ + private: + const std::vector& input_types_; + std::vector output_types_; + std::unordered_map attributesByName_; + ShapeInferenceOptions options_; ++ const FunctionProto* func_proto_; + }; + + std::vector InferFunctionOutputTypes( +diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h +index 2c63c910..b0e4c32d 100644 +--- a/onnx/shape_inference/implementation.h ++++ b/onnx/shape_inference/implementation.h +@@ -146,7 +146,7 @@ struct InferenceContextImpl : public InferenceContext { + const ShapeInferenceOptions& options, + DataValueMap* generatedShapeData = nullptr, + GraphInferenceContext* graphInferenceContext = nullptr) +- : graphInferenceContext_{graphInferenceContext}, options_(options) { ++ : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) { + for (auto& attr : *n.mutable_attribute()) { + attributesByName_[attr.name()] = &attr; + if (attr.has_g()) { +@@ -277,6 +277,19 @@ struct InferenceContextImpl : public InferenceContext { + return inferencer; + } + ++ std::string getDisplayName() const override { ++ if (node_ == nullptr) ++ return ""; ++ if (node_->domain().empty()) { ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type()); ++ return MakeString("node ", node_->op_type(), " (", node_->name(), ")"); ++ } ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]"); ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")"); ++ } ++ + 
std::vector allInputData_; + std::vector allInputSparseData_; + std::vector allShapeInputData_; +@@ -289,6 +302,7 @@ struct InferenceContextImpl : public InferenceContext { + // mutable as internal cache of GraphInferencer instances + mutable std::unordered_map> graphAttributeInferencers_; + ShapeInferenceOptions options_; ++ NodeProto* node_; + }; + + struct DataPropagationContextImpl : public DataPropagationContext { +diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc +index ef379d8f..b7dfe3c8 100644 +--- a/onnx/defs/math/defs.cc ++++ b/onnx/defs/math/defs.cc +@@ -2568,17 +2568,17 @@ ONNX_OPERATOR_SET_SCHEMA( + } + })); + +-void einsumRankInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string equation) { +- const size_t numInputs = ctx.getNumInputs(); +- if (numInputs < 1 || !hasNInputShapes(ctx, static_cast(numInputs))) { ++void einsumShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string const& equation) { ++ // Only accept letters for indices ++ auto is_letter = [](char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); }; ++ ++ const size_t num_inputs = ctx.getNumInputs(); ++ if (num_inputs < 1 || !hasNInputShapes(ctx, static_cast(num_inputs))) { + return; + } +- +- auto* output_shape = getOutputShape(ctx, 0); ++ ONNX_NAMESPACE::TensorShapeProto output_shape; + std::string left_equation; + +- equation.erase(std::remove(equation.begin(), equation.end(), ' '), +- equation.end()); // Remove space char + auto mid_index = equation.find("->"); + if (mid_index != std::string::npos) { + // Separate right and left hand sides of the equation +@@ -2595,73 +2595,130 @@ void einsumRankInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string equa + + // Parse the left-hand side + std::stringstream str(left_equation); ++ std::map label_maps; ++ std::set repeated_labels; ++ ONNX_NAMESPACE::TensorShapeProto dims_value, ellipsis_dims_value; ++ size_t num_labels = 0; ++ bool ellipsis_flag = true; ++ + while (!str.eof()) { + std::getline(str, term, ','); + auto ellipsis_index = term.find("..."); +- if (numInputs <= num_operands) { ++ if (num_inputs <= num_operands) { + fail_shape_inference("Number of input tensors does not match the operands in the equation."); + } +- size_t rank = ctx.getInputType(num_operands)->tensor_type().shape().dim_size(); ++ const auto& shape = ctx.getInputType(num_operands)->tensor_type().shape(); ++ size_t rank = shape.dim_size(); ++ size_t ellipsis_dims = 0; ++ ++ size_t term_size = 0; // number of legal indices for the current term ++ size_t num_illegal_char = 0; // number of illegal char before the current 'index' in the current term ++ ++ for (size_t index = 0; index < term.size(); ++index) { ++ if (is_letter(term[index])) { ++ term_size += 1; ++ } ++ } ++ ++ for (size_t index = 0; index < term.size(); ++index) { ++ if (index == ellipsis_index) { ++ // find ellipsis and record the dims represented by ellipsis ++ ellipsis_dims = rank - term_size; ++ if (ellipsis_flag) { ++ ellipsis_flag = false; ++ for (size_t i = 0; i < ellipsis_dims; i++) { ++ *ellipsis_dims_value.add_dim() = shape.dim(index + i - num_illegal_char); ++ } ++ } else { ++ for (size_t i = 0; i < ellipsis_dims; i++) { ++ const auto shape_dim = shape.dim(index + i - num_illegal_char); ++ const auto current_dim = ellipsis_dims_value.mutable_dim(i); ++ if (shape_dim.has_dim_value() && current_dim->has_dim_value() && ++ shape_dim.dim_value() > current_dim->dim_value() && current_dim->dim_value() == 1) { ++ current_dim->set_dim_value(shape_dim.dim_value()); ++ } ++ } ++ } 
++ index += 2; // skip the rest of dots ++ num_illegal_char += 3; ++ continue; ++ ++ } else if (!is_letter(term[index])) { ++ num_illegal_char += 1; ++ continue; ++ } ++ ++ const auto inserted = label_maps.insert({term[index], num_labels}).second; ++ if (inserted) { ++ *dims_value.add_dim() = shape.dim(index + ellipsis_dims - num_illegal_char); ++ ++num_labels; ++ } else { ++ repeated_labels.insert(term[index]); ++ } ++ } ++ + if (ellipsis_index != std::string::npos) { + // If there is an ellipsis, the number of dimensions it represents + // must be total dim - letter dimensions + if (num_ellipsis == 0) { +- if (rank + 3 < term.size()) { ++ if (rank < term_size) { + fail_shape_inference("Ellipsis represents incompatible dimensions."); + } +- num_ellipsis_indices = rank - term.size() + 3; ++ num_ellipsis_indices = rank - term_size; + } else { // ellipsis has been seen before. Check that if dimensions + // are compatible +- if (num_ellipsis_indices != rank - term.size() + 3) { ++ if (num_ellipsis_indices != rank - term_size) { + fail_shape_inference("Ellipsis represents incompatible dimensions."); + } + } + num_ellipsis++; + } else { +- if (rank != term.size()) { ++ if (rank != term_size) { + fail_shape_inference("Rank of input ", num_operands, " does not match the equation indices."); + } + } + num_operands++; + } + +- if (numInputs != num_operands) { ++ if (num_inputs != num_operands) { + fail_shape_inference("Number of input tensors does not match the operands in the equation."); + } + +- const size_t number_of_letters = 26; +- size_t num_letter_occurrences[number_of_letters] = {0}; + // Parse the provided right-hand side + if (mid_index != std::string::npos) { + std::string right_equation = equation.substr(mid_index + 2); + auto right_ellipsis_index = right_equation.find("..."); +- if (right_ellipsis_index != std::string::npos) { // Right-hand side contains ellipsis +- for (size_t i = 0; i < num_ellipsis_indices; ++i) { +- output_shape->add_dim(); ++ ++ for (size_t index = 0; index < right_equation.size(); ++index) { ++ // If there's an ellipsis, add its corresponding dimensions ++ if (index == right_ellipsis_index) { ++ for (size_t i = 0; i < num_ellipsis_indices; i++) { ++ *output_shape.add_dim() = ellipsis_dims_value.dim(i); ++ } ++ index += 2; // skip the rest of dots ++ continue; + } +- } +- for (char c : right_equation) { // Add a dimension per each character +- // in right hand equation +- if (c != '.') { +- output_shape->add_dim(); ++ ++ if (is_letter(right_equation[index])) { ++ *output_shape.add_dim() = dims_value.dim(label_maps[right_equation[index]]); + } + } + } else { // Infer the dimension for right-hand side +- // If there's an ellipsis, add it's corresponding dimensions ++ // If there's an ellipsis, add its corresponding dimensions + for (size_t i = 0; i < num_ellipsis_indices; i++) { +- output_shape->add_dim(); ++ *output_shape.add_dim() = ellipsis_dims_value.dim(i); + } +- for (size_t i = 0; i < left_equation.size(); i++) { // Count chars that appear exactly once on left hand side +- if ((left_equation.at(i) != ',') && (left_equation.at(i) != '.')) { +- num_letter_occurrences[left_equation.at(i) - 'a']++; +- } +- } +- for (size_t index = 0; index < number_of_letters; index++) { +- if (num_letter_occurrences[index] == 1) { +- output_shape->add_dim(); ++ // If no explicit output was given, generate an implicit output by ordering all the ++ // labels in alphabetic order (by ASCII value consistent with numpy, so Z < a). 
++ // Exclude any labels that occurred more than once, as these cancel out. ++ for (auto i : label_maps) { ++ if (repeated_labels.count(i.first) == 0) { ++ *output_shape.add_dim() = dims_value.dim(i.second); + } + } + } ++ ++ updateOutputShape(ctx, 0, output_shape); + } + + static const char* Einsum_ver12_doc = R"DOC( +@@ -2711,7 +2768,10 @@ ONNX_OPERATOR_SET_SCHEMA( + if (equation.compare("") == 0) { + return; + } +- einsumRankInference(ctx, equation); ++ ++ equation.erase(std::remove(equation.begin(), equation.end(), ' '), ++ equation.end()); // Remove space char ++ einsumShapeInference(ctx, equation); + })); + + const char* reduction_doc_sce = +diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py +index 75280f6c..5543fda0 100644 +--- a/onnx/test/shape_inference_test.py ++++ b/onnx/test/shape_inference_test.py +@@ -7026,7 +7026,7 @@ class TestShapeInference(TestShapeInferenceHelper): + [make_node("Einsum", ["x"], ["y"], equation="ij->ji")], + [], + ) +- self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (None, None))]) # type: ignore ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (4, 3))]) # type: ignore + + def test_einsum_dot(self) -> None: + graph = self._make_graph( +@@ -7050,7 +7050,7 @@ class TestShapeInference(TestShapeInferenceHelper): + [make_node("Einsum", ["x", "y"], ["z"], equation="ij,ab->ijab")], + [], + ) +- self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None, None))]) # type: ignore ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 5, 7, 9))]) # type: ignore + + def test_einsum_sum_along_dim(self) -> None: + graph = self._make_graph( +@@ -7058,7 +7058,7 @@ class TestShapeInference(TestShapeInferenceHelper): + [make_node("Einsum", ["x"], ["y"], equation="i j->i ")], + [], + ) +- self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (None,))]) # type: ignore ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore + + def test_einsum_ellipsis(self) -> None: + graph = self._make_graph( +@@ -7066,26 +7066,36 @@ class TestShapeInference(TestShapeInferenceHelper): + [make_node("Einsum", ["x"], ["y"], equation="... ii ->... 
i")], + [], + ) +- self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (None, None))]) # type: ignore ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 4))]) # type: ignore + + def test_einsum_ellipsis_2(self) -> None: + graph = self._make_graph( +- [("x", TensorProto.FLOAT, (2, 2, 2)), ("y", TensorProto.FLOAT, (2, 2, 2))], ++ [("x", TensorProto.FLOAT, (2, 3, 4)), ("y", TensorProto.FLOAT, (2, 4, 5))], + [make_node("Einsum", ["x", "y"], ["z"], equation="...ij,...jk->...ik")], + [], + ) + self._assert_inferred( +- graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None))] ++ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 5))] + ) # type: ignore + + def test_einsum_ellipsis_3(self) -> None: + graph = self._make_graph( +- [("x", TensorProto.FLOAT, (2, 2, 2)), ("y", TensorProto.FLOAT, (2, 2, 2))], ++ [("x", TensorProto.FLOAT, (2, 3, 4)), ("y", TensorProto.FLOAT, (2, 4, 5))], + [make_node("Einsum", ["x", "y"], ["z"], equation="...ij,...jk")], + [], + ) + self._assert_inferred( +- graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None))] ++ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 5))] ++ ) # type: ignore ++ ++ def test_einsum_ellipsis_broadcast(self) -> None: ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (1, 3, 4)), ("y", TensorProto.FLOAT, (32, 4, 5))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="...ij,...jk->...ik")], ++ [], ++ ) ++ self._assert_inferred( ++ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (32, 3, 5))] + ) # type: ignore + + def test_einsum_contraction(self) -> None: +@@ -7099,11 +7109,7 @@ class TestShapeInference(TestShapeInferenceHelper): + ) + self._assert_inferred( + graph, +- [ +- make_tensor_value_info( +- "z", TensorProto.FLOAT, (None, None, None, None, None) +- ) +- ], ++ [make_tensor_value_info("z", TensorProto.FLOAT, (5, 6, 7, 9, 10))], + ) # type: ignore + + def test_einsum_contraction_2(self) -> None: +@@ -7113,7 +7119,7 @@ class TestShapeInference(TestShapeInferenceHelper): + [], + ) + self._assert_inferred( +- graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None))] ++ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (4, 5))] + ) # type: ignore + + def test_einsum_batch_matmul(self) -> None: +@@ -7122,7 +7128,7 @@ class TestShapeInference(TestShapeInferenceHelper): + [make_node("Einsum", ["x", "y"], ["z"], equation="bij , b jk-> bik")], + [], + ) +- self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None))]) # type: ignore ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (5, 2, 4))]) # type: ignore + + def test_einsum_left_hand_eqn(self) -> None: + graph = self._make_graph( +@@ -7130,7 +7136,7 @@ class TestShapeInference(TestShapeInferenceHelper): + [make_node("Einsum", ["x", "y"], ["z"], equation="ij,kl")], + [], + ) +- self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None, None))]) # type: ignore ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 3, 4))]) # type: ignore + + def test_einsum_incorrect_num_inputs(self) -> None: + graph = self._make_graph( +@@ -7144,6 +7150,244 @@ class TestShapeInference(TestShapeInferenceHelper): + ) + self.assertRaises(onnx.shape_inference.InferenceError, self._inferred, graph) + ++ def test_einsum_view_A1(self) -> None: # returns a view of A1 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3,))], ++ 
[make_node("Einsum", ["x"], ["y"], equation="i")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore ++ ++ def test_einsum_sum_A1(self) -> None: # sums the values of A1 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3,))], ++ [make_node("Einsum", ["x"], ["y"], equation="i->")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, ())]) # type: ignore ++ ++ def test_einsum_element_wise_multiplication_A1_B1( ++ self, ++ ) -> None: # element-wise multiplication of A1 and B1 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3,)), ("y", TensorProto.FLOAT, (3,))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="i,i->i")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3,))]) # type: ignore ++ ++ def test_einsum_inner_product_A1_B1(self) -> None: # inner product of A1 and B1 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3,)), ("y", TensorProto.FLOAT, (3,))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="i,i->")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, ())]) # type: ignore ++ ++ def test_einsum_outer_product_A1_B1(self) -> None: # outer product of A1 and B1 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3,)), ("y", TensorProto.FLOAT, (3,))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="i,j->ij")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_view_A2(self) -> None: # returns a view of A2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ij->ij")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_view_A2_2(self) -> None: # returns a view of A2, another case ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ij")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_transpose_A2(self) -> None: # view transpose of A2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ji")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_transpose_A2_to_ij(self) -> None: # view transpose of A2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ji->ij")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_diag_A2(self) -> None: # view main diagonal of A2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ii->i")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore ++ ++ def test_einsum_trace_A2(self) -> None: # sums main diagonal of A2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ii->")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, ())]) # type: ignore ++ ++ def test_einsum_sum_A2(self) -> None: # sums the values 
of A2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ij->")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, ())]) # type: ignore ++ ++ def test_einsum_sum_columns_A2( ++ self, ++ ) -> None: # sum down the columns of A2 (across rows) ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ij->j")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore ++ ++ def test_einsum_sum_rows_A2(self) -> None: # sum horizontally along the rows of A2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x"], ["y"], equation="ij->i")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore ++ ++ def test_einsum_element_wise_multiplication_A2_B2( ++ self, ++ ) -> None: # element-wise multiplication of A2 and B2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,ij->ij")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_element_wise_multiplication_A2_B2_transpose( ++ self, ++ ) -> None: # element-wise multiplication of A2 and B2.T ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,ji->ij")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_matrix_multiplication_A2_B2( ++ self, ++ ) -> None: # matrix multiplication of A2 and B2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,jk")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_matrix_multiplication_A2_B2_to_ik( ++ self, ++ ) -> None: # matrix multiplication of A2 and B2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,jk->ik")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_matrix_multiplication_A3_B3( ++ self, ++ ) -> None: # matrix multiplication of A3 and B3 (a stack of 2D matrices) ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (2, 3, 3)), ("y", TensorProto.FLOAT, (2, 3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="bij,bjk->bik")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 3))]) # type: ignore ++ ++ def test_einsum_matrix_multiplication_A3_B3_transpose( ++ self, ++ ) -> None: # matrix multiplication of A3 and B3 (a stack of 2D matrices) ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (2, 3, 3)), ("y", TensorProto.FLOAT, (2, 3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="bij,bkj->bik")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 3))]) # type: ignore ++ ++ def test_einsum_inner_product_A2_B2(self) -> None: # inner product of A2 and B2 ++ graph = self._make_graph( ++ [("x", 
TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,kj->ik")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_row_multiplication_A2_B2( ++ self, ++ ) -> None: # each row of A2 multiplied by B2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,kj->ikj")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3, 3))]) # type: ignore ++ ++ def test_einsum_value_multiplication_A2_B2( ++ self, ++ ) -> None: # each value of A2 multiplied by B2 ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,kl->ijkl")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3, 3, 3))]) # type: ignore ++ ++ def test_einsum_scalar_times_array(self) -> None: # Scalar times array ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, ()), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation=",ij->ij")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore ++ ++ def test_einsum_matrix_vector_A2_B1(self) -> None: # Matrix and vector. ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3,))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,j->i")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3,))]) # type: ignore ++ ++ def test_einsum_diag_multiplication_A2_B2( ++ self, ++ ) -> None: # diagonals multiplied by each other ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ii,ii->i")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3,))]) # type: ignore ++ ++ def test_einsum_diag_dot_product_A2_B2(self) -> None: # dot product of diagonals ++ graph = self._make_graph( ++ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))], ++ [make_node("Einsum", ["x", "y"], ["z"], equation="ii,ii->")], ++ [], ++ ) ++ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, ())]) # type: ignore ++ + def test_negative_log_likehood_shape_is_NCdd(self) -> None: + N, C = 3, 4 + graph = self._make_graph( diff --git a/cmake/target_delayload.cmake b/cmake/target_delayload.cmake index 53f252a3e71ac..92273f5424233 100644 --- a/cmake/target_delayload.cmake +++ b/cmake/target_delayload.cmake @@ -6,9 +6,12 @@ function(target_delayload target_name) if(NOT MSVC) message(SEND_ERROR "Delayloading is only supported in MSVC") endif() - foreach(lib ${ARGN}) - target_link_options(${target_name} PRIVATE /DELAYLOAD:"${lib}") - endforeach() + if(onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS) + foreach(lib ${ARGN}) + target_link_options(${target_name} PRIVATE /DELAYLOAD:"${lib}") + endforeach() - target_link_libraries(${target_name} PRIVATE delayimp.lib) + target_link_libraries(${target_name} PRIVATE delayimp.lib) + endif() endfunction() + diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 159b8654c1cb1..fcb2c7d5de89b 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -66,6 +66,12 @@ "platform": "windows" } ], + "overrides": [ + 
{ + "name": "flatbuffers", + "version": "23.5.26" + } + ], "features": { "tests": { "description": "Build ONNXRuntime unit tests", diff --git a/csharp/ApiDocs/docfx.json b/csharp/ApiDocs/docfx.json index 0671d4aeb7d95..88a3283ad76e8 100644 --- a/csharp/ApiDocs/docfx.json +++ b/csharp/ApiDocs/docfx.json @@ -14,7 +14,7 @@ "disableDefaultFilter": false, "noRestore": true, "properties": { - "AllowUnsafeBlocks": true, + "AllowUnsafeBlocks": "true", "TargetFramework": "net8.0", "Nullable": "enable", "LangVersion": "8.0", diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 95207d158affe..6779fd60bcd0a 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -64,13 +64,6 @@ CMake creates a target to this project - - - - - - - @@ -153,7 +146,7 @@ CMake creates a target to this project $(BaseTargets);$(MobileTargets) + + + true + true + true + + + true + true + true + true + + $(ProjectDir)..\..\.. + + + true + + + Microsoft.ML.OnnxRuntime Microsoft.ML.OnnxRuntime @@ -66,54 +93,31 @@ Commit: $(BUILD_SOURCEVERSION) Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=$(BUILD_BUILDID) + README.md + LICENSE.txt + + + true + + true + ..\..\OnnxRuntime.snk + + $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb + AnyCPU;x86 default true - true - ..\..\OnnxRuntime.snk - - $(ProjectDir)..\..\.. - $(OnnxRuntimeRoot)\csharp x64 false false portable - - true - - - true - - - - - false - $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb Debug;Release;RelWithDebInfo - - true - true - true - - - true - true - true - - - $(OnnxRuntimeCsharpRoot)\..\build\Linux - $(OnnxRuntimeBuildDirectory)\$(Configuration) - - - - $(OnnxRuntimeCsharpRoot)\..\build\Windows $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) - - - - $(OnnxRuntimeCsharpRoot)\..\build\MacOS + $(OnnxRuntimeBuildDirectory)\$(Configuration) - + $(OrtConstants);__MOBILE__ @@ -155,12 +148,12 @@ $(OrtConstants);__ANDROID__ - + $(OrtConstants);__IOS__ - - + + $(OrtConstants);__ENABLE_COREML__ @@ -178,128 +171,6 @@ $(DefineConstants);$(OrtConstants) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + - - - + + + + + + + + + + + + + + + + + + - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index be157a0419fc0..d628b065ceaa7 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -1142,9 +1142,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id); - - [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] - public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tvm(IntPtr /*(OrtSessionOptions*) */ options, byte[] /*(char char*)*/ settings); #endif /// /// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance @@ -1272,7 +1269,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// /// Append an execution provider instance to the native OrtSessionOptions instance. /// - /// 'SNPE' and 'XNNPACK' are currently supported as providerName values. 
+ /// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values. /// /// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys. /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs index b04f7886b76dd..1b9cd7572170b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs @@ -330,7 +330,8 @@ public enum CoreMLFlags : uint COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004, COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008, COREML_FLAG_CREATE_MLPROGRAM = 0x010, - COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM, + COREML_FLAG_USE_CPU_AND_GPU = 0x020, + COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU, } /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 3acd84b3016de..bd450451a1265 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -146,27 +146,6 @@ public static SessionOptions MakeSessionOptionWithTensorrtProvider(OrtTensorRTPr } } - /// - /// A helper method to construct a SessionOptions object for TVM execution. - /// Use only if you have the onnxruntime package specific to this Execution Provider. - /// - /// settings string, comprises of comma separated key:value pairs. default is empty - /// A SessionsOptions() object configured for execution with TVM - public static SessionOptions MakeSessionOptionWithTvmProvider(String settings = "") - { - SessionOptions options = new SessionOptions(); - try - { - options.AppendExecutionProvider_Tvm(settings); - return options; - } - catch (Exception) - { - options.Dispose(); - throw; - } - } - /// /// A helper method to construct a SessionOptions object for ROCM execution. /// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider. @@ -397,20 +376,6 @@ public void AppendExecutionProvider_CoreML(CoreMLFlags coremlFlags = CoreMLFlags #endif } - /// - /// Use only if you have the onnxruntime package specific to this Execution Provider. - /// - /// string with TVM specific settings - public void AppendExecutionProvider_Tvm(string settings = "") - { -#if __MOBILE__ - throw new NotSupportedException("The TVM Execution Provider is not supported in this build"); -#else - var utf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(settings); - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tvm(handle, utf8)); -#endif - } - private class ExecutionProviderAppender { private byte[] _utf8ProviderName; @@ -430,16 +395,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt /// /// Append QNN, SNPE or XNNPACK execution provider /// - /// Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported. + /// Execution provider to add. 'QNN', 'SNPE' 'XNNPACK', 'CoreML and 'AZURE are currently supported. /// Optional key/value pairs to specify execution provider options. 
public void AppendExecutionProvider(string providerName, Dictionary providerOptions = null) { - if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE") - { - throw new NotSupportedException( - "Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method."); - } - if (providerOptions == null) { providerOptions = new Dictionary(); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 9b1df9357dc88..b4067806c5f93 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -76,7 +76,7 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); #endif - const uint ORT_API_VERSION = 20; + const uint ORT_API_VERSION = 21; #if NETSTANDARD2_0 IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION); api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi)); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index aa0e6ee62248a..17738da515134 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -146,10 +146,6 @@ public void TestSessionOptions() opt.AppendExecutionProvider_Nnapi(0); #endif -#if USE_TVM - opt.AppendExecutionProvider_Tvm("Vulkan -device=amd_apu"); -#endif - #if USE_OPENVINO opt.AppendExecutionProvider_OpenVINO(); #endif @@ -179,6 +175,12 @@ public void TestSessionOptions() ex = Assert.Throws(() => { opt.AppendExecutionProvider("QNN"); }); Assert.Contains("QNN execution provider is not supported in this build", ex.Message); #endif +#if USE_COREML + opt.AppendExecutionProvider("CoreML"); +#else + ex = Assert.Throws(() => { opt.AppendExecutionProvider("CoreML"); }); + Assert.Contains("CoreML execution provider is not supported in this build", ex.Message); +#endif opt.AppendExecutionProvider_CPU(1); } @@ -2041,7 +2043,7 @@ public SkipNonPackageTests() } // Test hangs on mobile. -#if !(ANDROID || IOS) +#if !(ANDROID || IOS) [Fact(DisplayName = "TestModelRunAsyncTask")] private async Task TestModelRunAsyncTask() { diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj index 60d18ad31e811..07ca7fe7c64bf 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj @@ -1,16 +1,19 @@  + + true + true + true + + $(ProjectDir)..\..\.. + netstandard2.0;net8.0 false - $(ProjectDir)..\.. 
AnyCPU bin\$(Configuration)\ - true - true - true - $(OnnxRuntimeCsharpRoot)\..\cmake\external\onnx + $(OnnxRuntimeRoot)\cmake\external\onnx 8981 @@ -22,30 +25,22 @@ ..\..\OnnxRuntime.snk Debug;Release;RelWithDebInfo + Microsoft.ML.OnnxRuntime.Tests Microsoft.ML.OnnxRuntime.Tests.Common - - - $(OnnxRuntimeCsharpRoot)\..\build\Linux - $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake - $(ProtocDirectory)\protoc - - - - $(OnnxRuntimeCsharpRoot)\..\build\Windows - $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake\$(Configuration) $(ProtocDirectory)\protoc.exe + + $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake + $(ProtocDirectory)\protoc + + - - $(OnnxRuntimeCsharpRoot)\..\build\MacOS - $(OnnxRuntimeBuildDirectory)\$(Configuration) $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake $(ProtocDirectory)\protoc @@ -102,28 +97,6 @@ - - - - PreserveNewest - false - - - - PreserveNewest - false - - - - PreserveNewest - false - - - @@ -132,16 +105,20 @@ - + - + + - + + @@ -152,20 +129,20 @@ + - TestData\%(Filename)%(Extension) + TestData\%(Filename)%(Extension) - - TestData\overridable_initializer.onnx + + TestData\overridable_initializer.onnx - - TestData\capi_symbolic_dims.onnx + + TestData\capi_symbolic_dims.onnx - diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/NativeLibraryInclude.props b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/NativeLibraryInclude.props new file mode 100644 index 0000000000000..3daab21dbcbac --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/NativeLibraryInclude.props @@ -0,0 +1,171 @@ + + + + + true + true + true + + + true + true + true + true + + + false + 1.20.0-dev-20241007 + + + + + + + + + + + + + + + + + + + + + + $(OnnxRuntimeRoot)\build\Windows + $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\Linux + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\MacOS + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\Android + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\iOS + iPhoneSimulator + $(Platform.ToLower()) + $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration)-$(PlatformLower) + + + + $(OnnxRuntimeRoot)\build\macOS + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + + PreserveNewest + true + + + + + + PreserveNewest + false + + + + + + PreserveNewest + false + + + + + + libs\libonnxruntime.so + + + + + + libs\libonnxruntime.dylib + Dynamic + True + True + + + + + + libs\libonnxruntime.dylib + Dynamic + True + True + + + + + + + + + false + true + false + true + false + true + + + + + + + + + + + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs index 27cde1dbe9ed8..46dd292e8514e 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs @@ -2180,10 +2180,13 @@ public void GetArrayString(TensorConstructor constructor) {22,23} } }"; + // remove \r so the newlines are just \n on all platforms + expected = expected.Replace("\r", ""); + var actual= tensor.GetArrayString().Replace("\r", ""); - Assert.Equal(expected, tensor.GetArrayString()); + Assert.Equal(expected, actual); - 
var expectedNoSpace = expected.Replace(Environment.NewLine, "").Replace(" ", ""); + var expectedNoSpace = expected.Replace("\n", "").Replace(" ", ""); Assert.Equal(expectedNoSpace, tensor.GetArrayString(false)); } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj index 210a04d78f107..e07448daeea7f 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj @@ -1,306 +1,125 @@  - - - true - true - true - true - $(ProjectDir)..\..\.. - - - - - net8.0-android;net8.0-ios;net8.0-maccatalyst - $(TargetFrameworks);net8.0-windows10.0.19041.0 - - - - - Exe - Microsoft.ML.OnnxRuntime.Tests.MAUI - true - true - enable - enable - true - - 8002 - - - $(DefineConstants);INCLUDE_FAILING_TESTS - $(DefineConstants);MODE_NON_INTERACTIVE_VISUAL - $(DefineConstants);MODE_XHARNESS - - - Microsoft.ML.OnnxRuntime.Tests.MAUI - - - ORT.CSharp.Tests.MAUI - - - 1.0 - 1 - - 15.0 - 13.1 - 30.0 - 10.0.17763.0 - 10.0.17763.0 - - true - ..\..\OnnxRuntime.snk - - - false - - - - - $(OnnxRuntimeRoot)\build\microsoft.ml.onnxruntime.1.18.1\runtimes - - true - - - - $(OnnxRuntimeRoot)\build\Windows - $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) - - $(PrebuiltRuntimesDir)\win-x64\native - - - $(OnnxRuntimeRoot)\build\Android - $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(PrebuiltRuntimesDir)\android\native\onnxruntime.aar - - - $(OnnxRuntimeRoot)\build\iOS - iPhoneSimulator - $(Platform.ToLower()) - $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration)-$(PlatformLower) - $(PrebuiltRuntimesDir)\ios\native\onnxruntime.xcframework - - - $(OnnxRuntimeRoot)\build\macOS - $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(PrebuiltRuntimesDir)\ios\native\onnxruntime.xcframework - - - - - - PreserveNewest - true - - - - - PreserveNewest - true - - - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - - - - - libs\libonnxruntime.so - - - - - - - - - - libs\libonnxruntime.dylib - Dynamic - True - True - - - - - Framework - True - True - - - - - - - libs\libonnxruntime.dylib - Dynamic - True - True - - - - - Framework - True - True - - - - - - - false - true - false - true - false - true - - false - true - false - true - false - true - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - InferenceTest.cs - - - OrtIoBindingAllocationTest.cs - - - TensorTests.cs - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - <_VisualStudioTestRunnerFiles Include="@(PackagingOutputs)" Condition="$([System.String]::Copy('%(PackagingOutputs.FullPath)').Contains('xunit.runner.visualstudio'))" /> - - - + + $(ProjectDir)..\..\.. 
+ + + + + + + net8.0-android;net8.0-ios;net8.0-maccatalyst + $(TargetFrameworks);net8.0-windows10.0.19041.0 + + + + + Exe + Microsoft.ML.OnnxRuntime.Tests.MAUI + true + true + enable + enable + true + + 8002 + + + $(DefineConstants);INCLUDE_FAILING_TESTS + $(DefineConstants);MODE_NON_INTERACTIVE_VISUAL + $(DefineConstants);MODE_XHARNESS + + + Microsoft.ML.OnnxRuntime.Tests.MAUI + + + ORT.CSharp.Tests.MAUI + + + 1.0 + 1 + + 15.0 + 13.1 + 30.0 + 10.0.17763.0 + 10.0.17763.0 + + true + ..\..\OnnxRuntime.snk + + + + + + + + + + + + + + + + + + + + + + + + InferenceTest.cs + + + OrtIoBindingAllocationTest.cs + + + TensorTests.cs + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + <_VisualStudioTestRunnerFiles + Include="@(PackagingOutputs)" + Condition="$([System.String]::Copy('%(PackagingOutputs.FullPath)').Contains('xunit.runner.visualstudio'))" /> + + + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/ReadMe.md b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/ReadMe.md new file mode 100644 index 0000000000000..07cb5fe7c9b3d --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/ReadMe.md @@ -0,0 +1,9 @@ +The MAUI test project can be optionally used with a pre-built ONNX Runtime native nuget package (Microsoft.ML.OnnxRuntime). + +To do so, specify the `UsePrebuiltNativePackage` and `CurrentOnnxRuntimeVersion` properties when building the project. These can be set via the command-line or as environment variables. + +For example: + +```cmd +dotnet build csharp\test\Microsoft.ML.OnnxRuntime.Tests.MAUI\Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj --property:UsePrebuiltNativePackage=true --property:CurrentOnnxRuntimeVersion=1.19.2 --source directory_containing_native_nuget_package --source https://api.nuget.org/v3/index.json +``` diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index b822c999e4d39..a8abcd2b4aa1c 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -1,4 +1,9 @@  + + $(ProjectDir)..\..\.. + + + net8.0 @@ -6,9 +11,7 @@ $(ProjectDir)..\.. 
AnyCPU;x86 bin\$(Configuration)\ - true - true - true + $(OnnxSourceDirectory)\onnx default @@ -35,19 +38,19 @@ - $(OnnxRuntimeCsharpRoot)\..\build\Linux + $(OnnxRuntimeRoot)\build\Linux $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(OnnxRuntimeCsharpRoot)\..\build\Windows + $(OnnxRuntimeRoot)\build\Windows $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) - $(OnnxRuntimeCsharpRoot)\..\build\MacOS + $(OnnxRuntimeRoot)\build\MacOS $(OnnxRuntimeBuildDirectory)\$(Configuration) @@ -58,15 +61,14 @@ PreserveNewest @@ -74,45 +76,39 @@ PreserveNewest false PreserveNewest false - - PreserveNewest - false - - + PreserveNewest false - - PreserveNewest - false - - + + PreserveNewest false - + + PreserveNewest false - + + PreserveNewest false + @@ -131,7 +127,7 @@ - + PreserveNewest false diff --git a/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist b/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist index 0004a4fdee5d5..fbb865624bbda 100644 --- a/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist +++ b/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist @@ -27,6 +27,6 @@ UIInterfaceOrientationLandscapeRight XSAppIconAssets - Assets.xcassets/appicon.appiconset + Assets.xcassets/onnxruntime_icon.appiconset diff --git a/dockerfiles/Dockerfile.cuda b/dockerfiles/Dockerfile.cuda index b5701eea82c6c..40f11dca623a7 100644 --- a/dockerfiles/Dockerfile.cuda +++ b/dockerfiles/Dockerfile.cuda @@ -2,16 +2,19 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------- -# Build onnxruntime-gpu python package with CUDA 12.6 & CUDNN 9.4 for python 3.12 in Ubuntu 24.04 for Nvidia GPU. +# Build onnxruntime-gpu python package with CUDA 12.x & CUDNN 9.x for python 3.12 in Ubuntu 24.04. # If memory is less than 64GB, you may change "--parallel" to "--parallel 4" to avoid out-of-memory error. -FROM nvcr.io/nvidia/cuda:12.6.1-devel-ubuntu24.04 +ARG CUDA_VERSION=12.6.1 +ARG CUDNN_VERSION=9.5.0.50 +ARG OS=ubuntu24.04 -# Target CUDA device with compute capability >= 6.1 +FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-${OS} +ARG CUDA_VERSION +ARG CUDNN_VERSION ARG CMAKE_CUDA_ARCHITECTURES="61;70;75;80;86;90" ENV DEBIAN_FRONTEND=noninteractive -MAINTAINER Changming Sun "chasun@microsoft.com" # Add source code to /code ADD . /code @@ -34,16 +37,18 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && rm -rf /var/lib/apt/lists/* # Install CUDNN 9.4.0.58 for building ONNX Runtime with CUDA. -RUN wget https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.4.0.58_cuda12-archive.tar.xz \ +RUN cudnn_tar="cudnn-linux-x86_64-${CUDNN_VERSION}_cuda${CUDA_VERSION%%.*}-archive.tar.xz" \ + && wget "https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${cudnn_tar}" \ && mkdir -p /code/build/cudnn \ - && tar -Jxvf cudnn-linux-x86_64-9.4.0.58_cuda12-archive.tar.xz -C /code/build/cudnn --strip=1 + && tar -Jxvf ${cudnn_tar} -C /code/build/cudnn --strip=1 \ + && rm -f ${cudnn_tar} # Create a virtual environment and install dependencies, then build ONNX Runtime with CUDA support. RUN cd /code \ && python3 -m venv /code/env \ && . 
/code/env/bin/activate \ && pip install --upgrade psutil setuptools wheel packaging \ - && pip install -r tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt \ + && pip install -r /code/tools/ci_build/github/linux/python/requirements.txt \ && python /code/tools/ci_build/build.py --build_dir /code/build/Linux \ --allow_running_as_root --skip_submodule_sync \ --use_cuda --cuda_home /usr/local/cuda \ @@ -51,38 +56,55 @@ RUN cd /code \ --build_shared_lib --skip_tests \ --config Release --build_wheel --update --build --parallel \ --cmake_generator Ninja \ - --enable_cuda_nhwc_ops \ --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) "CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}" onnxruntime_BUILD_UNIT_TESTS=OFF # Start second stage to copy the build artifacts -FROM nvcr.io/nvidia/cuda:12.6.1-runtime-ubuntu24.04 -ENV DEBIAN_FRONTEND=noninteractive +FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-runtime-${OS} +ARG CUDA_VERSION +ARG CUDNN_VERSION +ARG GIT_COMMIT +ARG GIT_BRANCH +ARG ONNXRUNTIME_VERSION + +# Make sure the required build arguments are set. See README.md for more information. +RUN test -n ${GIT_COMMIT:?} +RUN test -n ${GIT_BRANCH:?} +RUN test -n ${ONNXRUNTIME_VERSION:?} + +LABEL CUDA_VERSION="${CUDA_VERSION}" +LABEL CUDNN_VERSION="${CUDNN_VERSION}" +LABEL maintainer="Changming Sun " +LABEL onnxruntime_version="${ONNXRUNTIME_VERSION}" +LABEL onnxruntime_git_branch="${GIT_BRANCH}" +LABEL onnxruntime_git_commit="${GIT_COMMIT}" # Copy built wheel and license COPY --from=0 /code/build/Linux/Release/dist /ort COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt -# Set LD_LIBRARY_PATH so that runtime can load CUDA and CUDNN DLLs. -# CUDNN will be installed by nvidia-cudnn-cu12 python package later. -# Its location is in the site-packages directory, which can be retrieved like the following: -# python -c "import sysconfig; print(sysconfig.get_path('purelib'))" +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive +ENV CUDNN_VERSION=$CUDNN_VERSION +ENV ONNXRUNTIME_VERSION=$ONNXRUNTIME_VERSION +# CUDNN from nvidia-cudnn-cu12 python package is located in the site-packages directory of python virtual environment. ENV LD_LIBRARY_PATH="/ort/env/lib/python3.12/site-packages/nvidia/cudnn/lib:/usr/local/cuda/lib64" -# Install runtime dependencies, and run a simple test to verify the installation. +# Install runtime dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ libstdc++6 \ ca-certificates \ python3-pip \ python3.12-venv \ - unattended-upgrades \ - && unattended-upgrade \ && python3 -m venv /ort/env \ && . /ort/env/bin/activate \ && pip install /ort/*.whl \ - && pip install nvidia-cudnn-cu12==9.4.0.58 \ + && pip install nvidia-cudnn-cu${CUDA_VERSION%%.*}==${CUDNN_VERSION} \ && python -c 'import onnxruntime; print(onnxruntime.get_available_providers())' \ && rm -rf /ort/*.whl \ && rm -rf /var/lib/apt/lists/* # Ensure the virtual environment is always activated when running commands in the container. RUN echo ". 
/ort/env/bin/activate" >> ~/.bashrc + +# Set the default command to start an interactive bash shell +CMD [ "/bin/bash" ] diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index c3541a8bd3425..876a07e4ffaf6 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -5,12 +5,12 @@ # Dockerfile to run ONNXRuntime with MIGraphX integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 +FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} RUN apt-get update &&\ apt-get install -y migraphx diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index 39e75a68a369f..d1ebdae3cbdd6 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -11,7 +11,7 @@ FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} AS builder ENV WORKDIR_PATH=/home/openvino WORKDIR $WORKDIR_PATH -ENV DEBIAN_FRONTEND noninteractive +ENV DEBIAN_FRONTEND=noninteractive ARG DEVICE=CPU ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime.git @@ -41,7 +41,7 @@ RUN tar cvf GPL_sources.tar.gz /sources # Deploy stage FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} -ENV DEBIAN_FRONTEND noninteractive +ENV DEBIAN_FRONTEND=noninteractive USER root COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/dist/*.whl ./ COPY --from=builder /GPL_sources.tar.gz ./ @@ -50,7 +50,7 @@ ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER RUN usermod -a -G video,users ${BUILD_USER} -ENV WORKDIR_PATH /home/${BUILD_USER} +ENV WORKDIR_PATH=/home/${BUILD_USER} WORKDIR ${WORKDIR_PATH} USER ${BUILD_USER} diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index c242933f677f0..aca8c3feaff71 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -5,14 +5,14 @@ # Dockerfile to run ONNXRuntime with ROCm integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 +FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main WORKDIR /code -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/Dockerfile.tensorrt b/dockerfiles/Dockerfile.tensorrt index ef51d41c5ff1b..24947df6308a6 100644 --- a/dockerfiles/Dockerfile.tensorrt +++ b/dockerfiles/Dockerfile.tensorrt @@ -17,7 +17,7 @@ RUN apt-get update &&\ RUN unattended-upgrade WORKDIR /code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime with TensorRT RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/Dockerfile.vitisai b/dockerfiles/Dockerfile.vitisai index 
e11ab70a61332..c6226155e01e3 100644 --- a/dockerfiles/Dockerfile.vitisai +++ b/dockerfiles/Dockerfile.vitisai @@ -22,8 +22,8 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:$PATH -ENV LD_LIBRARY_PATH /opt/xilinx/xrt/lib:$LD_LIBRARY_PATH +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:$PATH +ENV LD_LIBRARY_PATH=/opt/xilinx/xrt/lib:$LD_LIBRARY_PATH WORKDIR /code RUN . $VAI_ROOT/conda/etc/profile.d/conda.sh &&\ diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 008587a01082b..9f83fc390eee7 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -40,18 +40,33 @@ The docker file supports both x86_64 and ARM64(aarch64). You may use docker's "- However, we cannot build the code for 32-bit ARM in such a way since a 32-bit compiler/linker might not have enough memory to generate the binaries. ## CUDA -**Ubuntu 22.04, CUDA 12.1, CuDNN 8** +**Ubuntu 24.04, CUDA 12.x, CuDNN 9.x** 1. Build the docker image from the Dockerfile in this repository. + Choose available [cuda version](https://hub.docker.com/r/nvidia/cuda/tags) or [cudnn version](https://pypi.org/project/nvidia-cudnn-cu12/#history), then build docker image like the following: + ``` - docker build -t onnxruntime-cuda -f Dockerfile.cuda .. + git submodule update --init + docker build -t onnxruntime-cuda --build-arg CUDA_VERSION=12.6.1 \ + --build-arg CUDNN_VERSION=9.5.0.50 \ + --build-arg GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD) \ + --build-arg GIT_COMMIT=$(git rev-parse HEAD) \ + --build-arg ONNXRUNTIME_VERSION=$(cat ../VERSION_NUMBER) \ + -f Dockerfile.cuda .. + ``` + To inspect the labels of the built image, run the following: + ``` + docker inspect onnxruntime-cuda + ``` 2. Run the Docker image ``` - docker run --gpus all -it onnxruntime-cuda + docker run --rm --gpus all -it onnxruntime-cuda + ``` or + ``` nvidia-docker run -it onnxruntime-cuda ``` @@ -277,7 +292,7 @@ Nothing else from ONNX Runtime source tree will be copied/installed to the image Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropriate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime). ## MIGraphX -**Ubuntu 20.04, ROCm6.0, MIGraphX** +**Ubuntu 22.04, ROCm6.2.3, MIGraphX** 1. Build the docker image from the Dockerfile in this repository. ``` @@ -291,7 +306,7 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` ## ROCm -**Ubuntu 20.04, ROCm6.0** +**Ubuntu 22.04, ROCm6.2.3** 1. Build the docker image from the Dockerfile in this repository. ``` diff --git a/docs/Coding_Conventions_and_Standards.md b/docs/Coding_Conventions_and_Standards.md index f18f1036efee8..02af7ddaa49be 100644 --- a/docs/Coding_Conventions_and_Standards.md +++ b/docs/Coding_Conventions_and_Standards.md @@ -164,22 +164,16 @@ dependencies to run linters locally. If you want to see what lintrunner init will install, run `lintrunner init --dry-run`. 
-To lint local changes: - -```bash -lintrunner -``` - -To format files and apply suggestions: +To format local changes: ```bash lintrunner -a ``` -To lint all files: +To format all files: ```bash -lintrunner --all-files +lintrunner -a --all-files ``` To show help text: diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index da4d0b7f66c37..6ea3f93cdea12 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1167,7 +1167,7 @@ This version of the operator has been available since version 1 of the 'com.micr
mask_index (optional) : M
Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)
attention_bias (optional) : T
-additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_key (optional) : T
past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attentionWhen past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.
past_value (optional) : T
@@ -1175,9 +1175,9 @@ This version of the operator has been available since version 1 of the 'com.micr
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).Cross Attention doesn't need this input.
beam_width (optional) : M
-The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.
+The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.
cache_indirection (optional) : M
-A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration
+A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
@@ -1192,7 +1192,7 @@ This version of the operator has been available since version 1 of the 'com.micr
present_value (optional) : T
present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
qk (optional) : V
-normalized Q * K, of shape (batch_size, num_heads, 1, head_size).
+normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length).
#### Type Constraints
@@ -1261,9 +1261,9 @@ This version of the operator has been available since version 1 of the 'com.micr
past_sequence_length : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
beam_width (optional) : M
-The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.
+The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.
cache_indirection (optional) : M
-A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration
+A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration
#### Outputs
@@ -1545,7 +1545,7 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.DynamicTimeWarping**
- Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])have the lowest cost among all paths from (0, 0) to (M-1, N-1).
+ Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return dynamic time warping of shape [2, x], where the path made by all points (output[0][t], output[1][t])have the lowest cost among all paths from (0, 0) to (M-1, N-1).
#### Version
@@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
(Optional) Hardware architecture.
main_context : int
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
+max_size : int
+max size in the context. Usage depends on the EP.
notes : string
(Optional) Some notes for the model
onnx_model_filename : string
@@ -5974,7 +5976,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.UnfoldTensor** - Returns a tensor which contains all slices of size size from input tensor in the dimension dim. Step between two slices is given by step. If sizedim is the size of dimension dim for input tensor, the size of dimension dim in the returned tensor will be (sizedim - size) / step + 1. An additional dimension of size size is appended in the returned tensor. + Returns a tensor which contains all slices of size `size` from input tensor in the dimension `dim`. Step between two slices is given by `step`. If `sizedim` is the size of dimension `dim` for input tensor, the size of dimension `dim` in the returned tensor will be `(sizedim - size) / step + 1`. An additional dimension of size `size` is appended in the returned tensor. #### Version diff --git a/docs/How_To_Update_ONNX_Dev_Notes.md b/docs/How_To_Update_ONNX_Dev_Notes.md index 4d8a286bde66e..199e6671f6a1a 100644 --- a/docs/How_To_Update_ONNX_Dev_Notes.md +++ b/docs/How_To_Update_ONNX_Dev_Notes.md @@ -21,7 +21,7 @@ This file should be generated. See [cgmanifests/README](/cgmanifests/README.md) - [onnxruntime/test/python/requirements.txt](/onnxruntime/test/python/requirements.txt) - [tools/ci_build/github/linux/docker/scripts/requirements.txt](/tools/ci_build/github/linux/docker/scripts/requirements.txt) - [tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt](/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt) -- [tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt](/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt) +- [tools/ci_build/github/linux/python/requirements.txt](/tools/ci_build/github/linux/python/requirements.txt) - Run `git grep -rn "onnx==1" .` to find other locations and update this document if necessary. 1. If there is any change to `cmake/external/onnx/onnx/*.in.proto`, you need to regenerate OnnxMl.cs. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index a9176605d9175..eeb8ebb3ccefe 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -258,7 +258,8 @@ Do not modify directly.* |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 11]|**T** = tensor(double), tensor(float)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| -|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**TS** = tensor(float)| +|||[10, 20]|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| @@ -452,6 +453,7 @@ Do not modify directly.* |SVMClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |SVMRegressor|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(float)| |Scaler|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|TreeEnsemble|*in* X:**T**
*out* Y:**T**|5+|**T** = tensor(double), tensor(float)| |TreeEnsembleClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |TreeEnsembleRegressor|*in* X:**T**
*out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)| @@ -468,9 +470,11 @@ Do not modify directly.* |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| +|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float)| |DynamicQuantizeLSTM|*in* X:**T**
*in* W:**T2**
*in* R:**T2**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* W_scale:**T**
*in* W_zero_point:**T2**
*in* R_scale:**T**
*in* R_zero_point:**T2**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(float)
**T1** = tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| +|DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float)| |ExpandDims|*in* X:**T**
*in* axis:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**axis** = tensor(int32)| |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float)| @@ -518,6 +522,7 @@ Do not modify directly.* |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| +|UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| |WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| @@ -550,8 +555,12 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||12|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||12|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -917,6 +926,35 @@ Do not modify directly.* |WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | +|**Operator Domain:** *com.ms.internal.nhwc*|||| +|AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|||10|**T** = tensor(float), tensor(float16)| +|||[7, 9]|**T** = tensor(float), tensor(float16)| +|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(double), tensor(float), tensor(float16)| +|||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| +|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(float), tensor(float16)| +|ConvTranspose|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(float), tensor(float16)| +|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|GlobalAveragePool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|GlobalMaxPool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|MaxPool|*in* X:**T**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**
*out* Indices:**I**|12+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)| +|||11|**I** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|||10|**I** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|||[8, 9]|**I** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|||[1, 7]|**T** = tensor(float), tensor(float16)| +|SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +| | +| | @@ -965,7 +1003,8 @@ Do not modify directly.* |||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|21+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| @@ -979,7 +1018,8 @@ Do not modify directly.* |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||4+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| |ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)| @@ -1017,7 +1057,8 @@ Do not modify directly.* |Expand|*in* input:**T**
*in* shape:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |EyeLike|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Flatten|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Flatten|*in* input:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1046,11 +1087,13 @@ Do not modify directly.* |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|GroupNorm||21+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|21+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1137,7 +1180,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||7+|**T** = tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1148,7 +1192,8 @@ Do not modify directly.* |||12+|**T** = tensor(float), tensor(float16), tensor(int32)
**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| -|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|||10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)| |||19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| @@ -1249,7 +1294,8 @@ Do not modify directly.* |SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|21+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| @@ -1289,7 +1335,8 @@ Do not modify directly.* |TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||10+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/docs/TVM_EP.md b/docs/TVM_EP.md deleted file mode 100644 index df59d5c05855c..0000000000000 --- a/docs/TVM_EP.md +++ /dev/null @@ -1,319 +0,0 @@ -# TVM Execution Provider - -## Contents - -- [Introduction](#introduction) -- [Build](#build-onnx-runtime-with-the-tvm-execution-provider) - - [Linux](#linux) - - [Windows](#windows) -- [Configuration options](#configuration-options) -- [Performance Tuning](#performance-tuning) - - [Using precompiled model](#using-precompiled-model) -- [Samples](#samples) -- [Known issues](#known-issues) - - -## Introduction - -TVM is an execution provider for ONNX Runtime that is built on top of Apache TVM. It enables ONNX Runtime users to leverage Apache TVM model optimizations. -TVM EP is currently in "Preview". It's been tested to work on a handful of models on Linux or Windows, but not on MacOS. - -## Build ONNX Runtime with the TVM Execution Provider - -### **Linux** -Install the minimal pre-requisites on Ubuntu/Debian like linux operating systems: -```bash -apt-get install -y python3 python3-dev python3-pip python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev llvm-12 -pip3 install numpy decorator attrs nasm -``` -Note: since ONNX Runtime with TVM EP is built with Intel ipp-crypto library there are new requirements. Compiler gcc (and g++) version should be equal to or higher than 8.2. nasm version should be 2.14.02 or higher. Problem with small nasm version can be seen [here](https://github.com/intel/ipp-crypto/issues/9) or [here](https://bugzilla.nasm.us/show_bug.cgi?id=3392205). For ubuntu LTS 18 `apt-get install nasm` is not enough due to it has version 2.13.02, see how to install from sources instruction [here](https://stackoverflow.com/questions/36144930/steps-to-install-nasm-offline-on-ubuntu). - -Also, the current implementation has `NVidia GPU` support for TVM EP. For now, you can use only `NVidia GPU` with CUDA Toolkit support. -To do this, make sure you have installed the NVidia driver and CUDA Toolkit. -More detailed instructions can be found on the [official page](https://developer.nvidia.com/cuda-toolkit). - -Clone this repo. -In order to build ONNXRT you will need to have CMake 3.18 or higher. In Ubuntu 20.04 you can use the following commands to install the latest version of CMake: - -```bash -sudo apt-get update -sudo apt-get install gpg wget - -wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | sudo tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null - -echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ focal main' | sudo tee /etc/apt/sources.list.d/kitware.list >/dev/null -sudo apt-get update - -sudo rm /usr/share/keyrings/kitware-archive-keyring.gpg -sudo apt-get install kitware-archive-keyring - -sudo apt-get install cmake -``` - -Build ONNX Runtime (TVM x86): -```bash -./build.sh --config Release --enable_pybind --build_wheel --parallel --skip_tests --skip_onnx_tests --use_tvm -``` - -Build ONNX Runtime (TVM with CUDA support): -```bash -./build.sh --config Release --enable_pybind --build_wheel --parallel --skip_tests --skip_onnx_tests --use_tvm --tvm_cuda_runtime -``` - -This command builds both `TVM` and `onnxruntime-tvm`. 
It creates two wheel, one for each project. -Build the python API for ONNX Runtime instead of using the standard package. Instructions for this are given below. - -Package for TVM: -```bash -cd -python3 -m pip uninstall tvm -y -whl_path=$(find ./build//Release/_deps/tvm-src/python/dist -name "*.whl") -python3 -m pip install $whl_path -``` - -Package for TVM EP: -```bash -cd -python3 -m pip uninstall onnxruntime onnxruntime-tvm -y -whl_path=$(find ./build//Release/dist -name "*.whl") -python3 -m pip install $whl_path -``` - -Alternatively, you can set `PYTHONPATH` to tell python where to find the ONNXRT library and the TVM library. -```bash -export PYTHONPATH=/build//Release:${PYTHONPATH} -export PYTHONPATH=/build//Release/_deps/tvm-src/python:${PYTHONPATH} -``` - -### **Windows** -Install the minimal prerequisites on Windows: Git, CMake, Visual Studio, Python, LLVM -- Git: Download Git for Windows from [here](https://git-scm.com/download/win) and install it. Please make sure that the git.exe path is included in the environment variable. By default, it should be added. To check git after the installation use `git --version` in command line (cmd). -- CMake: use [the link](https://cmake.org/download/) to download and install CMake. msi-file is recommended for it. To verify CMake installation use `cmake --version` in cmd. -- Visual Studio: Download from [here](https://visualstudio.microsoft.com/ru/downloads/) and install Visual Studio 20** Community & Visual Studio Build Tools respectively. It is recommended not to change the default installation path. Chose "Desktop development with C++" workload and make sure that both options of “MSVC [contemporary version] C++ build tools” and “Windows 10 SDK” are selected. -- Python: Download Python 3.* from [here](https://www.python.org/downloads/windows/) and install it. Please have a check on the option of “Add Python to PATH”, so the installer will include the Python directory into the environment variable directly. To check python after the installation use `python` from cmd. The expected output is similar to the following: -```cmd -Python 3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)] on win32 -Type "help", "copyright", "credits" or "license" for more information. ->>> -``` -Use `quit()` to exit from python interface. -- LLVM: the compiler is not necessary for pure ONNX Runtime installation but it is needed for TVM EP by default. -```cmd -git clone --depth 1 --branch release/11.x https://github.com/llvm/llvm-project.git -cmake -S llvm -B build -DLLVM_ENABLE_PROJECTS="clang;libcxx;libcxxabi" -DLLVM_TARGETS_TO_BUILD=X86 -Thost=x64 -DCMAKE_BUILD_TYPE=Release -G "Visual Studio 17 2022" -cmake --build ./build --config Release -``` -- Dependencies of ipp-crypto:
-1. Install the asm compiler (nasm) on Windows with: -```cmd -winget install nasm -i -``` -          -Add it to PATH (instructions for the Windows GUI are [here](https://www.computerhope.com/issues/ch000549.htm#dospath)) or from cmd: -```cmd -set PATH="%PATH%;C:\Program Files\NASM" -``` -          -or -```cmd -setx PATH "%PATH%;C:\Program Files\NASM" -``` -          -Verify with `nasm --version` at the command prompt.
-       -2. Install OpenSSL on Windows using the MSI installer from [here](https://slproweb.com/products/Win32OpenSSL.html). -Add the directory containing the executable (e.g. "C:\Program Files\OpenSSL-Win64\bin") to PATH (see instructions above).
-          -Verify with `openssl version` at the command prompt.
-       -3. A correct build of ipp-crypto requires specific environment variables for a supported MSVC compiler. The long way is to adjust the environment by following the instructions [here](https://docs.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=msvc-170&viewFallbackFrom=vs-2017). The quick way is to use the VS Developer command prompt, where the environment is already configured, or to add the required paths to a standard Windows command prompt: -```cmd -set INCLUDE=C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.32.31326\include;C:\Program Files (x86)\Windows Kits\10\include\10.0.22621.0\ucrt -``` -          -Note that the MSVC and Kit versions are specific to the Visual Studio installed on the machine; the values shown here are only an example. -
-
- -To use an NVIDIA GPU (optional), CUDA and cuDNN must be installed. -- CUDA: Install CUDA from the [link](https://developer.nvidia.com/cuda-11.0-download-archive). -- cuDNN: download the cuDNN installer from [here](https://developer.nvidia.com/rdp/cudnn-archive). Choose v8.* for the corresponding CUDA v11.*, unzip it, and move the cuDNN files as follows: -1. [unzipped dir]\bin\ → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\bin -2. [unzipped dir]\include\ → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\include -3. [unzipped dir]\lib\ → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\lib - -To verify the CUDA installation, use `nvcc --version` in cmd. -
-
- -#### **Build ONNX Runtime with TVM Execution Provider from source (Python):** -- Use command line and clone sources from github: -```cmd -git clone --recursive https://github.com/Microsoft/onnxruntime -cd onnxruntime -``` -- CPU build: -``` -build.bat --config Release --enable_pybind --build_wheel --skip_tests --parallel --use_tvm --skip_onnx_tests --cmake_generator "Visual Studio 17 2022" --llvm_config /build/Release/bin/llvm-config.exe -``` -- GPU build: -``` -build.bat --config Release --enable_pybind --build_wheel --skip_tests --parallel --use_tvm --skip_onnx_tests --cmake_generator "Visual Studio 17 2022" --llvm_config /build/Release/bin/llvm-config.exe --use_cuda --cudnn_home “C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.*” --cuda_home “C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.*” -``` -In both cases (CPU, GPU) there are the following options for cmake generator: "Visual Studio 17 2022" and "Ninja". Also handshake mechanism can be switched on by `--use_tvm_hash` flag. At the latter case ipp-crypto library is built with dependencies, see details above. -- Install python wheel package for ONNX Runtime:
-The default path to the package is `/build/Windows/Release/Release/dist`. Note that it differs from the package path on Linux. Before installation, check the names of the wheel packages and use the corresponding one. It may look like the following: -```cmd -python -m pip install .\onnxruntime\build\Windows\Release\Release\dist\onnxruntime_tvm-1.6.0-cp38-cp38-win_amd64.whl -``` -- Install the python wheel package for TVM, since its python API is used inside the TVM EP:
-It can be looked like the following: -```cmd -python -m pip install .\onnxruntime\build\Windows\Release\_deps\tvm-src\python\dist\tvm-0.9.dev1728+g3425ed846-cp39-cp39-win_amd64.whl -``` -- Verify result by python script. Note: python should not be launched from directory containing 'onnxruntime' directory for correct result: -```python -import onnxruntime -print(onnxruntime.__version__) -print(onnxruntime.get_device()) -print(onnxruntime.get_available_providers()) -``` -- Uninstall procedure: -```cmd -pip uninstall onnxruntime-tvm -``` - -#### **Build ONNX Runtime with TVM Execution Provider from source (C#):** -- Use command line and clone sources from github: -```cmd -git clone --recursive https://github.com/Microsoft/onnxruntime -cd onnxruntime -``` -- CPU build: - -Make sure you download [nuget.exe](https://docs.microsoft.com/en-us/nuget/install-nuget-client-tools#nugetexe-cli) and add path to it into `PATH` env. -``` -build.bat --config Release --build_nuget --skip_tests --parallel --use_tvm --skip_onnx_tests --cmake_generator "Visual Studio 17 2022" --llvm_config llvm-config.exe -``` -- Install C# nuget package for TVM EP. Default path to the package is `\build\Windows\Release\Release`. - - -## Configuration options -TVM Executor Provider can be configured with the following provider options: -1. Python -```python -po = [dict(executor=tvm_executor_type, - so_folder=folder_with_pretuned_files, - check_hash=check_hash, - hash_file_path=hash_file_path, - target=client_target, - target_host=client_target_host, - opt_level=client_opt_level, - freeze_weights=freeze, - to_nhwc=layout_transform, - tuning_type=tvm_optimizer_type, - tuning_file_path=client_tuning_logfile, - input_names = input_names_str, - input_shapes = input_shapes_str)] -tvm_session = onnxruntime.InferenceSession(model_path, providers=["TvmExecutionProvider"], provider_options=po) -``` - -2. C# - -Currently, only precompiled models are supported in C# (see the related section below). - -```CSharp -SessionOptions session_options = new SessionOptions{}; -string tvm_ep_options = - $"executor: {tvm_executor_type}, " + - $"so_folder: {folder_with_pretuned_files}, " + - $"check_hash: {check_hash}, " + - $"hash_file_path: {hash_file_path}, " + - $"target: {client_target}, " + - $"target_host: {client_target_host}, " + - $"opt_level: {client_opt_level}, " + - $"freeze_weights: {freeze}, " + - $"to_nhwc: {layout_transform}, " + - $"tuning_type: {tvm_optimizer_type}, " + - $"tuning_file_path: {client_tuning_logfile}, " + - $"input_names: {input_names_str}, " + - $"input_shapes: {input_shapes_str}"; - -session_options.AppendExecutionProvider_Tvm(tvm_ep_options); -using var tvm_session = new InferenceSession(modelFilePath, session_options); -``` -
- -- `executor` is executor type used by TVM. There is choice between two types: GraphExecutor and VirtualMachine which are corresponded to "graph" and "vm" tags. VirtualMachine is used by default. -- `so_folder` is path to folder with set of files (.ro-, .so/.dll-files and weights) obtained after model tuning. It uses these files for executor compilation instead of onnx-model. But the latter is still needed for ONNX Runtime. -- `check_hash` means that it is necessary to perform a HASH check for the model obtained in the `so_folder` parameter. It is `False` by default. -- `hash_file_path` is path to file that contains the pre-computed HASH for the ONNX model which result of tuning locates in the path passed by `so_folder` parameter. - If an empty string was passed as this value, then the file will be searched in the folder that was passed in the `so_folder` parameter. -- `target` and `target_host` are strings like in TVM (e.g. "llvm --mcpu=avx2"). When using accelerators, target may be something like `cuda` while target_host may be `llvm -mtriple=x86_64-linux-gnu` -- `opt_level` is TVM optimization level. It is 3 by default -- `freeze_weights` means that all model weights are kept on compilation stage otherwise they are downloaded each inference. True is recommended value for the best performance. It is true by default. -- `to_nhwc` switches on special model transformations, particularly data layout, which Octomizer is used. It allows to work correctly with tuning logs obtained from Octomizer. It is false by default. -- `tuning_type` defines the type of TVM tuning logs being used, and can be set to either `AutoTVM` (1st gen auto tuning logs) or `Ansor` (2nd gen auto tuning logs). By default this option is set to `AutoTVM`. -- `tuning_file_path` is path to AutoTVM or Ansor tuning file which gives specifications for given model and target for the best performance. (See below for more details). - -TVM supports models with fixed graph only. If your model has unknown dimensions in input shapes (excluding batch size) you must provide the shape using the `input_names` and `input_shapes` provider options. Below is an example of what must be passed to `provider_options`: -```python -input_names = "input_1 input_2" -input_shapes = "[1 3 224 224] [1 2]" -``` - -## Performance Tuning -TVM optimizes machine learning models through an automated tuning process that produces model variants specific to targeted hardware architectures. This process also generates 'tuning logs' that the TVM EP relies on to maximize model performance. These logs can be acquired for your model by either using TVM as described here: - -AutoTVM: -https://tvm.apache.org/docs/how_to/tune_with_autotvm/index.html - -Ansor (Autoscheduling): -https://tvm.apache.org/docs/how_to/tune_with_autoscheduler/index.html - -or by using logs generated through the OctoML platform (https://onnx.octoml.ai) using instructions [here](https://help.octoml.ai/en/articles/5814452-using-octoml-platform-logs-with-onnx-rt-tvm-ep) - -Using the TVM EP with TVM tuning logs also requires users to turn off ONNX Runtime preprocessing. To do this, the following `SessionOptions()` can be used: -``` -so = onnxruntime.SessionOptions() -so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - -tvm_session = onnxruntime.InferenceSession(model_path, sess_options=so, providers=["TvmExecutionProvider"], provider_options=po) -``` - -### **Using precompiled model** -It is also possible to use a precompiled model. 
- -The compiled model can be obtained using the [OctoML platform](https://onnx.octoml.ai) -or compiled directly (see **Support precompiled model** section in -[Sample notebook for ResNet50 inference with TVM EP](https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb) -for more information on model compilation). - -In order to use the precompiled model, you only need to pass two options: -* **executor** - `vm` (`VirtualMachine`) must be used as a value -(this functionality is not supported for `GraphExecutor`); -* **so_folder** - as a value, you must pass the path to the directory where -the files of the precompiled model are located. -* **check_hash** - (optional) if you want to check the hash, you must pass `True` as the value. -* **hash_file_path** - (optional) by default, the file containing the hash for the tuned model will be searched in the directory that is passed in the `so_folder` parameter. - If you want to specify a different location, then you must pass the path to the file that contains the desired hash as a value. - -You can read more about these options in the section [Configuration options](#configuration-options) above. - - -## Samples -- [Sample notebook for ResNet50 inference with TVM EP](https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb) - -## Known issues -- At this moment, the TVM EP has only been verified on UNIX/Linux and Windows systems. -- Some compatibility issues have been found between ONNX and Google protobuf. `AttributeError: module 'google.protobuf.internal.containers' has no attribute 'MutableMapping'`. This usually occurs during `import onnx` in any python script for protobuf version >= 3.19.0 and ONNX version <= 1.8.1. To resolve the issue, Google protobuf and ONNX can be reinstalled separately or together using: -``` -pip3 uninstall onnx -y -pip3 install onnx==1.10.1 -pip3 uninstall protobuf -y -pip3 install protobuf==3.19.1 -``` - -The following pairs of ONNX and protobuf versions have been found to be compatible: -- 3.17.3 and 1.8.0 -- 3.19.1 and 1.10.1 diff --git a/docs/python/README.rst b/docs/python/README.rst index 5a45bf6cef8ed..cce966f7d7d0c 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_. +Some ONNX data types (such as TensorProto.BFLOAT16, TensorProto.FLOAT8E4M3FN and TensorProto.FLOAT8E5M2) are not supported by NumPy. You can directly bind an input or output to a Torch tensor of the corresponding data type +(such as torch.bfloat16, torch.float8_e4m3fn and torch.float8_e5m2) in GPU memory. + +.. 
code-block:: python + + x = torch.ones([3], dtype=torch.float8_e5m2, device='cuda:0') + y = torch.empty([3], dtype=torch.bfloat16, device='cuda:0') + + binding = session.io_binding() + binding.bind_input( + name='X', + device_type='cuda', + device_id=0, + element_type=TensorProto.FLOAT8E5M2, + shape=tuple(x.shape), + buffer_ptr=x.data_ptr(), + ) + binding.bind_output( + name='Y', + device_type='cuda', + device_id=0, + element_type=TensorProto.BFLOAT16, + shape=tuple(y.shape), + buffer_ptr=y.data_ptr(), + ) + session.run_with_iobinding(binding) + API Details =========== diff --git a/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb b/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb deleted file mode 100644 index 830495bdfb98d..0000000000000 --- a/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb +++ /dev/null @@ -1,657 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "72476497", - "metadata": {}, - "source": [ - "# ONNX Runtime: Tutorial for TVM execution provider\n", - "\n", - "This notebook shows a simple example for model inference with TVM EP.\n", - "\n", - "\n", - "#### Tutorial Roadmap:\n", - "1. Prerequistes\n", - "2. Accuracy check for TVM EP\n", - "3. Configuration options\n", - "4. Support precompiled model" - ] - }, - { - "cell_type": "markdown", - "id": "9345cbab", - "metadata": {}, - "source": [ - "## 1. Prerequistes\n", - "\n", - "Make sure that you have installed all the necessary dependencies described in the corresponding paragraph of the documentation.\n", - "\n", - "Also, make sure you have the `tvm` and `onnxruntime-tvm` packages in your pip environment. \n", - "\n", - "If you are using `PYTHONPATH` variable expansion, make sure it contains the following paths: `/onnxruntime/cmake/external/tvm_update/python` and `/onnxruntime/build/Linux/Release`." - ] - }, - { - "cell_type": "markdown", - "id": "da4ca21f", - "metadata": {}, - "source": [ - "### Common import\n", - "\n", - "These packages can be delivered from standard `pip`." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0f072875", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import onnx\n", - "import tempfile\n", - "import numpy as np\n", - "from typing import List, AnyStr\n", - "from onnx import ModelProto, helper, checker, mapping" - ] - }, - { - "cell_type": "markdown", - "id": "118670aa", - "metadata": {}, - "source": [ - "### Specialized import\n", - "\n", - "It is better to collect these packages from source code in order to clearly understand what is available to you right now." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a5502966", - "metadata": {}, - "outputs": [], - "source": [ - "import onnxruntime\n", - "\n", - "import tvm\n", - "import tvm.relay\n", - "import tvm.testing\n", - "import tvm.runtime\n", - "import tvm.runtime.vm\n", - "import tvm.relay.backend.vm\n", - "import tvm.contrib.download" - ] - }, - { - "cell_type": "markdown", - "id": "b7313183", - "metadata": {}, - "source": [ - "### Helper functions for working with ONNX ModelProto\n", - "\n", - "This set of helper functions allows you to recognize the meta information of the models. This information is needed for more versatile processing of ONNX models." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "7d0a36e8", - "metadata": {}, - "outputs": [], - "source": [ - "def get_onnx_input_names(model: ModelProto) -> List[AnyStr]:\n", - " inputs = [node.name for node in model.graph.input]\n", - " initializer = [node.name for node in model.graph.initializer]\n", - " inputs = list(set(inputs) - set(initializer))\n", - " return sorted(inputs)\n", - "\n", - "\n", - "def get_onnx_output_names(model: ModelProto) -> List[AnyStr]:\n", - " return [node.name for node in model.graph.output]\n", - "\n", - "\n", - "def get_onnx_input_types(model: ModelProto) -> List[np.dtype]:\n", - " input_names = get_onnx_input_names(model)\n", - " return [\n", - " mapping.TENSOR_TYPE_TO_NP_TYPE[node.type.tensor_type.elem_type]\n", - " for node in sorted(model.graph.input, key=lambda node: node.name) if node.name in input_names\n", - " ]\n", - "\n", - "\n", - "def get_onnx_input_shapes(model: ModelProto) -> List[List[int]]:\n", - " input_names = get_onnx_input_names(model)\n", - " return [\n", - " [dv.dim_value for dv in node.type.tensor_type.shape.dim]\n", - " for node in sorted(model.graph.input, key=lambda node: node.name) if node.name in input_names\n", - " ]\n", - "\n", - "\n", - "def get_random_model_inputs(model: ModelProto) -> List[np.ndarray]:\n", - " input_shapes = get_onnx_input_shapes(model)\n", - " input_types = get_onnx_input_types(model)\n", - " assert len(input_types) == len(input_shapes)\n", - " inputs = [np.random.uniform(size=shape).astype(dtype) for shape, dtype in zip(input_shapes, input_types)]\n", - " return inputs" - ] - }, - { - "cell_type": "markdown", - "id": "f0de1682", - "metadata": {}, - "source": [ - "### Wrapper helper functions for Inference\n", - "\n", - "Wrapper helper functions for running model inference using ONNX Runtime EP." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "258ce9e9", - "metadata": {}, - "outputs": [], - "source": [ - "def get_onnxruntime_output(model: ModelProto, inputs: List, provider_name: AnyStr) -> np.ndarray:\n", - " output_names = get_onnx_output_names(model)\n", - " input_names = get_onnx_input_names(model)\n", - " assert len(input_names) == len(inputs)\n", - " input_dict = {input_name: input_value for input_name, input_value in zip(input_names, inputs)}\n", - "\n", - " inference_session = onnxruntime.InferenceSession(model.SerializeToString(), providers=[provider_name])\n", - " output = inference_session.run(output_names, input_dict)\n", - "\n", - " # Unpack output if there's only a single value.\n", - " if len(output) == 1:\n", - " output = output[0]\n", - " return output\n", - "\n", - "\n", - "def get_cpu_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:\n", - " return get_onnxruntime_output(model, inputs, \"CPUExecutionProvider\")\n", - "\n", - "\n", - "def get_tvm_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:\n", - " return get_onnxruntime_output(model, inputs, \"TvmExecutionProvider\")" - ] - }, - { - "cell_type": "markdown", - "id": "cc17d3b2", - "metadata": {}, - "source": [ - "### Helper function for checking accuracy\n", - "\n", - "This function uses the TVM API to compare two output tensors. The tensor obtained using the `CPUExecutionProvider` is used as a reference.\n", - "\n", - "If a mismatch is found between tensors, an appropriate exception will be thrown." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "4e598907", - "metadata": {}, - "outputs": [], - "source": [ - "def verify_outputs(\n", - " lhs: List[np.ndarray],\n", - " rhs: List[np.ndarray],\n", - " rtol: float = 5e-5,\n", - " atol: float = 5e-5\n", - ") -> None:\n", - " for lhs_tensor, rhs_tensor in zip(lhs, rhs):\n", - " tvm.testing.assert_allclose(lhs_tensor, rhs_tensor, rtol=rtol, atol=atol)\n", - " assert lhs_tensor.dtype == rhs_tensor.dtype\n", - " print(\"Same output, congratulations!\")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f33a372b", - "metadata": {}, - "outputs": [], - "source": [ - "def verify_with_ort_with_inputs(\n", - " model,\n", - " inputs,\n", - " out_shape=None,\n", - " opset=None,\n", - " freeze_params=False,\n", - " dtype=\"float32\",\n", - " rtol=1e-5,\n", - " atol=1e-5,\n", - " opt_level=1,\n", - "):\n", - " if opset is not None:\n", - " model.opset_import[0].version = opset\n", - "\n", - " ort_out = get_cpu_onnxruntime_output(model, inputs)\n", - " tvm_out = get_tvm_onnxruntime_output(model, inputs)\n", - " verify_outputs(ort_out, tvm_out, rtol, atol)" - ] - }, - { - "cell_type": "markdown", - "id": "8c62b01a", - "metadata": {}, - "source": [ - "### Helper functions for download models\n", - "\n", - "These functions use the TVM API to download models from the ONNX Model Zoo." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "324c00e7", - "metadata": {}, - "outputs": [], - "source": [ - "BASE_MODEL_URL = \"https://github.com/onnx/models/raw/master/\"\n", - "MODEL_URL_COLLECTION = {\n", - " \"ResNet50-v1\": \"vision/classification/resnet/model/resnet50-v1-7.onnx\",\n", - " \"ResNet50-v2\": \"vision/classification/resnet/model/resnet50-v2-7.onnx\",\n", - " \"SqueezeNet-v1.1\": \"vision/classification/squeezenet/model/squeezenet1.1-7.onnx\",\n", - " \"SqueezeNet-v1.0\": \"vision/classification/squeezenet/model/squeezenet1.0-7.onnx\",\n", - " \"Inception-v1\": \"vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx\",\n", - " \"Inception-v2\": \"vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx\",\n", - "}\n", - "\n", - "\n", - "def get_model_url(model_name):\n", - " return BASE_MODEL_URL + MODEL_URL_COLLECTION[model_name]\n", - "\n", - "\n", - "def get_name_from_url(url):\n", - " return url[url.rfind(\"/\") + 1 :].strip()\n", - "\n", - "\n", - "def find_of_download(model_name):\n", - " model_url = get_model_url(model_name)\n", - " model_file_name = get_name_from_url(model_url)\n", - " return tvm.contrib.download.download_testdata(model_url, model_file_name, module=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "90fb7c5c", - "metadata": {}, - "source": [ - "## 2. Accuracy check for TVM EP \n", - "\n", - "This section will check the accuracy. The check will be to compare the output tensors for `CPUExecutionProvider` and `TvmExecutionProvider`. 
See the description of `verify_with_ort_with_inputs` function used above.\n", - "\n", - "\n", - "### Check for simple architectures" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "c739ed5c", - "metadata": {}, - "outputs": [], - "source": [ - "def get_two_input_model(op_name: AnyStr) -> ModelProto:\n", - " dtype = \"float32\"\n", - " in_shape = [1, 2, 3, 3]\n", - " in_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]\n", - " out_shape = in_shape\n", - " out_type = in_type\n", - "\n", - " layer = helper.make_node(op_name, [\"in1\", \"in2\"], [\"out\"])\n", - " graph = helper.make_graph(\n", - " [layer],\n", - " \"two_input_test\",\n", - " inputs=[\n", - " helper.make_tensor_value_info(\"in1\", in_type, in_shape),\n", - " helper.make_tensor_value_info(\"in2\", in_type, in_shape),\n", - " ],\n", - " outputs=[\n", - " helper.make_tensor_value_info(\n", - " \"out\", out_type, out_shape\n", - " )\n", - " ],\n", - " )\n", - " model = helper.make_model(graph, producer_name=\"two_input_test\")\n", - " checker.check_model(model, full_check=True)\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7048ee6d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Same output, congratulations!\n", - "****************** Success! ******************\n" - ] - } - ], - "source": [ - "onnx_model = get_two_input_model(\"Add\")\n", - "inputs = get_random_model_inputs(onnx_model)\n", - "verify_with_ort_with_inputs(onnx_model, inputs)\n", - "print(\"****************** Success! ******************\")" - ] - }, - { - "cell_type": "markdown", - "id": "52c880f4", - "metadata": {}, - "source": [ - "### Check for DNN architectures " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "f5d465dc", - "metadata": {}, - "outputs": [], - "source": [ - "def get_onnx_model(model_name):\n", - " model_path = find_of_download(model_name)\n", - " onnx_model = onnx.load(model_path)\n", - " return onnx_model" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "68daac7e", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Same output, congratulations!\n", - "****************** Success! ******************\n" - ] - } - ], - "source": [ - "model_name = \"ResNet50-v1\"\n", - "\n", - "onnx_model = get_onnx_model(model_name)\n", - "inputs = get_random_model_inputs(onnx_model)\n", - "verify_with_ort_with_inputs(onnx_model, inputs)\n", - "print(\"****************** Success! ******************\")" - ] - }, - { - "cell_type": "markdown", - "id": "e27f64a2", - "metadata": {}, - "source": [ - "## 3. Configuration options\n", - "\n", - "This section shows how you can configure TVM EP using custom options. For more details on the options used, see the corresponding section of the documentation." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "a053f59f", - "metadata": {}, - "outputs": [], - "source": [ - "provider_name = \"TvmExecutionProvider\"\n", - "provider_options = dict(\n", - " target=\"llvm -mtriple=x86_64-linux-gnu\",\n", - " target_host=\"llvm -mtriple=x86_64-linux-gnu\",\n", - " opt_level=3,\n", - " freeze_weights=True,\n", - " tuning_file_path=\"\",\n", - " tuning_type=\"Ansor\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "3f6e6f01", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"ResNet50-v1\"\n", - "onnx_model = get_onnx_model(model_name)\n", - "input_dict = {\n", - " input_name: input_value for input_name, input_value in zip(\n", - " get_onnx_input_names(onnx_model),\n", - " get_random_model_inputs(onnx_model),\n", - " )\n", - "}\n", - "output_names = get_onnx_output_names(onnx_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "85ab83f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "****************** Output shape: (1, 1000) ******************\n" - ] - } - ], - "source": [ - "tvm_session = onnxruntime.InferenceSession(\n", - " onnx_model.SerializeToString(),\n", - " providers=[provider_name],\n", - " provider_options=[provider_options],\n", - ")\n", - "output = tvm_session.run(output_names, input_dict)[0]\n", - "print(f\"****************** Output shape: {output.shape} ******************\")" - ] - }, - { - "cell_type": "markdown", - "id": "b704374b", - "metadata": {}, - "source": [ - "## 4. Support precompiled model\n", - "\n", - "Wrapper functions that allow you to compile the model and save it in the desired format." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "8150942b", - "metadata": {}, - "outputs": [], - "source": [ - "def compile_virtual_machine(model: onnx.ModelProto, target_str: AnyStr) -> tvm.runtime.vm.Executable:\n", - " ir_mod, params = tvm.relay.frontend.from_onnx(\n", - " model,\n", - " opset=model.opset_import[0].version,\n", - " freeze_params=True,\n", - " )\n", - " target = tvm.target.Target(target=target_str, host=target_str)\n", - " return tvm.relay.backend.vm.compile(ir_mod, target)\n", - "\n", - "\n", - "def serialize_virtual_machine(vm_exec: tvm.runtime.vm.Executable) -> AnyStr:\n", - " temp_directory = tempfile.mkdtemp()\n", - " path_consts = os.path.join(temp_directory, \"consts\")\n", - " vm_exec.move_late_bound_consts(path_consts, byte_limit=256)\n", - " lib_path = os.path.join(temp_directory, f\"model.so\")\n", - " code_path = os.path.join(temp_directory, f\"model.ro\")\n", - " code, lib = vm_exec.save()\n", - " lib.export_library(lib_path)\n", - " with open(code_path, \"wb\") as fo:\n", - " fo.write(code)\n", - " return temp_directory" - ] - }, - { - "cell_type": "markdown", - "id": "9cbb987e", - "metadata": {}, - "source": [ - "Preparation of the ONNX model." 
- ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "febb9d72", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"ResNet50-v1\"\n", - "onnx_model = get_onnx_model(model_name)\n", - "input_dict = {\n", - " input_name: input_value for input_name, input_value in zip(\n", - " get_onnx_input_names(onnx_model),\n", - " get_random_model_inputs(onnx_model),\n", - " )\n", - "}\n", - "output_names = get_onnx_output_names(onnx_model)" - ] - }, - { - "cell_type": "markdown", - "id": "b05b251a", - "metadata": {}, - "source": [ - "Compiling the ONNX model using `VirtualMachine` (TVM)." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b4b999ee", - "metadata": {}, - "outputs": [], - "source": [ - "compiled_vm_exec = compile_virtual_machine(onnx_model, target_str=\"llvm\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "e3408c15", - "metadata": {}, - "outputs": [], - "source": [ - "so_folder = serialize_virtual_machine(compiled_vm_exec)" - ] - }, - { - "cell_type": "markdown", - "id": "311405e8", - "metadata": {}, - "source": [ - "Preparing `ProviderOptions` and launching `TVM EP` inference.\n", - "\n", - "In order to use the precompiled model, you only need to pass two options:\n", - "* **executor** - `vm` (`VirtualMachine`) must be used as a value (this functionality is not supported for `GraphExecutor`);\n", - "* **so_folder** - as a value, you must pass the path to the directory where the files of the precompiled model are located." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "8927293c", - "metadata": {}, - "outputs": [], - "source": [ - "provider_name = \"TvmExecutionProvider\"\n", - "provider_options = dict(\n", - " executor=\"vm\",\n", - " so_folder=so_folder,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "d7532863", - "metadata": {}, - "outputs": [], - "source": [ - "tvm_session = onnxruntime.InferenceSession(\n", - " onnx_model.SerializeToString(),\n", - " providers=[provider_name],\n", - " provider_options=[provider_options],\n", - ")\n", - "tvm_output = tvm_session.run(output_names, input_dict)" - ] - }, - { - "cell_type": "markdown", - "id": "1c0b983e", - "metadata": {}, - "source": [ - "Let's make sure that the output values match those that can be obtained through `CPUExecutionProvider`:" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "c3de2299", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Same output, congratulations!\n" - ] - } - ], - "source": [ - "verify_outputs(\n", - " tvm_output[0],\n", - " get_cpu_onnxruntime_output(\n", - " onnx_model,\n", - " input_dict.values()\n", - " ),\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/include/onnxruntime/core/common/logging/logging.h b/include/onnxruntime/core/common/logging/logging.h index 9cdf42e222051..ab2c476f2975a 100644 --- a/include/onnxruntime/core/common/logging/logging.h +++ b/include/onnxruntime/core/common/logging/logging.h @@ -17,7 +17,6 @@ #include "core/common/logging/macros.h" #include 
"core/common/logging/severity.h" #include "core/common/logging/sink_types.h" -#include "core/platform/ort_mutex.h" #include "date/date.h" /* @@ -259,7 +258,7 @@ class LoggingManager final { std::unique_ptr sink_; #ifdef _WIN32 - mutable OrtMutex sink_mutex_; + mutable std::mutex sink_mutex_; #endif Severity default_min_severity_; const bool default_filter_user_data_; diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index a5b5d2edde46c..0d9e6db1a7748 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -214,6 +214,14 @@ class IExecutionProvider { return Status::OK(); } + /** + Called when InferenceSession::SetEpDynamicOptions is called + */ + virtual common::Status SetEpDynamicOptions(gsl::span /*keys*/, + gsl::span /*values*/) { + return Status::OK(); + } + /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for the provider. diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 7b3d04ee66d9e..aaf533135429c 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -8,6 +8,9 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { +namespace logging { +class Logger; +} using KernelCreateMap = std::multimap; using KernelDefHashes = std::vector>; @@ -33,6 +36,7 @@ class KernelRegistry { // Kernel matching uses the types from the node and the kernel_type_str_resolver. Status TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const; // map of type constraint name to required type @@ -42,6 +46,7 @@ class KernelRegistry { // Kernel matching uses the explicit type constraint name to required type map in type_constraints. 
Status TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; /** @@ -61,13 +66,15 @@ class KernelRegistry { std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; static bool HasImplementationOf(const KernelRegistry& r, const Node& node, ProviderType exec_provider, - const IKernelTypeStrResolver& kernel_type_str_resolver) { + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) { const KernelCreateInfo* info; - Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info); + Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info); return st.IsOK(); } @@ -83,6 +90,7 @@ class KernelRegistry { Status TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; // Check whether the types of inputs/outputs of the given node match the extra diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index f15543f22f21d..6f658ab65be20 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -17,6 +17,7 @@ struct OrtDevice { static const DeviceType GPU = 1; // Nvidia or AMD static const DeviceType FPGA = 2; static const DeviceType NPU = 3; // Ascend + static const DeviceType DML = 4; struct MemType { // Pre-defined memory types. diff --git a/include/onnxruntime/core/graph/schema_registry.h b/include/onnxruntime/core/graph/schema_registry.h index b128e91afa9ae..ca51e3621b2c6 100644 --- a/include/onnxruntime/core/graph/schema_registry.h +++ b/include/onnxruntime/core/graph/schema_registry.h @@ -12,7 +12,6 @@ #include "core/graph/constants.h" #include "core/common/common.h" #include "core/common/status.h" -#include "core/platform/ort_mutex.h" namespace onnxruntime { using OpName_Domain_Version_Schema_Map = std::unordered_map< @@ -102,7 +101,7 @@ class OnnxRuntimeOpSchemaRegistry : public IOnnxRuntimeOpSchemaCollection { common::Status RegisterOpSchemaInternal(ONNX_NAMESPACE::OpSchema&& op_schema); - OrtMutex mutex_; + std::mutex mutex_; OpName_Domain_Version_Schema_Map map_; DomainToVersionRangeMap domain_version_range_map_; diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 6cff153c336f0..31b0f22340510 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -53,6 +53,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); @@ -84,6 +85,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = 
{}, concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); diff --git a/include/onnxruntime/core/platform/Barrier.h b/include/onnxruntime/core/platform/Barrier.h index 1148b052bd9af..bddc3ba8903f6 100644 --- a/include/onnxruntime/core/platform/Barrier.h +++ b/include/onnxruntime/core/platform/Barrier.h @@ -10,9 +10,9 @@ #include #include "core/common/spin_pause.h" -#include "core/platform/ort_mutex.h" #include +#include #include namespace onnxruntime { @@ -40,7 +40,7 @@ class Barrier { assert(((v + delta) & ~1) != 0); return; // either count has not dropped to 0, or waiter is not waiting } - std::unique_lock l(mu_); + std::unique_lock l(mu_); assert(!notified_); notified_ = true; cv_.notify_all(); @@ -55,7 +55,7 @@ class Barrier { unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); if ((v >> 1) == 0) return; - std::unique_lock l(mu_); + std::unique_lock l(mu_); while (!notified_) { cv_.wait(l); } @@ -63,8 +63,8 @@ class Barrier { } private: - OrtMutex mu_; - OrtCondVar cv_; + std::mutex mu_; + std::condition_variable cv_; std::atomic state_; // low bit is waiter flag bool notified_; const bool spin_; diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index d4411a6d72356..a7c63c507d1ba 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -50,7 +50,6 @@ #include "core/common/denormal.h" #include "core/common/inlined_containers_fwd.h" #include "core/common/spin_pause.h" -#include "core/platform/ort_mutex.h" #include "core/platform/ort_spin_lock.h" #include "core/platform/Barrier.h" @@ -460,7 +459,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back = back_.load(std::memory_order_relaxed); Elem& e = array_[(back - 1) & kMask]; @@ -484,7 +483,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back = back_.load(std::memory_order_relaxed); w_idx = (back - 1) & kMask; @@ -509,7 +508,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back; Elem* e; @@ -555,7 +554,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif Elem& e = array_[w_idx]; ElemState s = e.state.load(std::memory_order_relaxed); @@ -631,7 +630,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE OrtSpinLock spin_lock_; #else - OrtMutex mutex_; + std::mutex mutex_; #endif // Low log(kSize) + 1 bits in front_ and back_ contain rolling index of @@ -1440,7 +1439,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ThreadStatus seen = GetStatus(); if (seen == ThreadStatus::Blocking || seen == ThreadStatus::Blocked) { - std::unique_lock lk(mutex); + std::unique_lock lk(mutex); // Blocking state exists only transiently during the SetBlock() method // while holding the lock. 
We may observe it at the start of this // function, but after acquiring the lock then the target thread @@ -1468,11 +1467,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter status = ThreadStatus::Spinning; } - void SetBlocked(std::function should_block, + bool SetBlocked(std::function should_block, std::function post_block) { - std::unique_lock lk(mutex); - assert(GetStatus() == ThreadStatus::Spinning); - status.store(ThreadStatus::Blocking, std::memory_order_relaxed); + std::unique_lock lk(mutex); + auto old_status = status.exchange(ThreadStatus::Blocking, std::memory_order_seq_cst); + if (old_status != ThreadStatus::Spinning) { + // Encountered a logical error + return false; + } if (should_block()) { status.store(ThreadStatus::Blocked, std::memory_order_relaxed); do { @@ -1481,12 +1483,13 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter post_block(); } status.store(ThreadStatus::Spinning, std::memory_order_relaxed); + return true; } private: std::atomic status{ThreadStatus::Spinning}; - OrtMutex mutex; - OrtCondVar cv; + std::mutex mutex; + std::condition_variable cv; }; Environment& env_; @@ -1559,62 +1562,66 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // Attempt to block if (!t) { - td.SetBlocked( // Pre-block test - [&]() -> bool { - bool should_block = true; - // Check whether work was pushed to us while attempting to block. We make - // this test while holding the per-thread status lock, and after setting - // our status to ThreadStatus::Blocking. - // - // This synchronizes with ThreadPool::Schedule which pushes work to the queue - // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): - // - // Main thread: Worker: - // #1 Push work #A Set status blocking - // #2 Read worker status #B Check queue - // #3 Wake if blocking/blocked - // - // If #A is before #2 then main sees worker blocked and wakes - // - // If #A if after #2 then #B will see #1, and we abandon blocking - assert(!t); - t = q.PopFront(); - if (t) { - should_block = false; - } - - // No work pushed to us, continue attempting to block. The remaining - // test is to synchronize with termination requests. If we are - // shutting down and all worker threads blocked without work, that's - // we are done. - if (should_block) { - blocked_++; - if (done_ && blocked_ == num_threads_) { - should_block = false; - // Almost done, but need to re-check queues. - // Consider that all queues are empty and all worker threads are preempted - // right after incrementing blocked_ above. Now a free-standing thread - // submits work and calls destructor (which sets done_). If we don't - // re-check queues, we will exit leaving the work unexecuted. - if (NonEmptyQueueIndex() != -1) { - // Note: we must not pop from queues before we decrement blocked_, - // otherwise the following scenario is possible. Consider that instead - // of checking for emptiness we popped the only element from queues. - // Now other worker threads can start exiting, which is bad if the - // work item submits other work. So we just check emptiness here, - // which ensures that all worker threads exit at the same time. - blocked_--; - } else { - should_exit = true; + if (!td.SetBlocked( // Pre-block test + [&]() -> bool { + bool should_block = true; + // Check whether work was pushed to us while attempting to block. We make + // this test while holding the per-thread status lock, and after setting + // our status to ThreadStatus::Blocking. 
+ // + // This synchronizes with ThreadPool::Schedule which pushes work to the queue + // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): + // + // Main thread: Worker: + // #1 Push work #A Set status blocking + // #2 Read worker status #B Check queue + // #3 Wake if blocking/blocked + // + // If #A is before #2 then main sees worker blocked and wakes + // + // If #A if after #2 then #B will see #1, and we abandon blocking + assert(!t); + t = q.PopFront(); + if (t) { + should_block = false; + } + + // No work pushed to us, continue attempting to block. The remaining + // test is to synchronize with termination requests. If we are + // shutting down and all worker threads blocked without work, that's + // we are done. + if (should_block) { + blocked_++; + if (done_ && blocked_ == num_threads_) { + should_block = false; + // Almost done, but need to re-check queues. + // Consider that all queues are empty and all worker threads are preempted + // right after incrementing blocked_ above. Now a free-standing thread + // submits work and calls destructor (which sets done_). If we don't + // re-check queues, we will exit leaving the work unexecuted. + if (NonEmptyQueueIndex() != -1) { + // Note: we must not pop from queues before we decrement blocked_, + // otherwise the following scenario is possible. Consider that instead + // of checking for emptiness we popped the only element from queues. + // Now other worker threads can start exiting, which is bad if the + // work item submits other work. So we just check emptiness here, + // which ensures that all worker threads exit at the same time. + blocked_--; + } else { + should_exit = true; + } + } } - } - } - return should_block; - }, - // Post-block update (executed only if we blocked) - [&]() { - blocked_--; - }); + return should_block; + }, + // Post-block update (executed only if we blocked) + [&]() { + blocked_--; + })) { + // Encountered a fatal logic error in SetBlocked + should_exit = true; + break; + } // Thread just unblocked. Unless we picked up work while // blocking, or are exiting, then either work was pushed to // us, or it was pushed to an overloaded queue diff --git a/include/onnxruntime/core/platform/ort_mutex.h b/include/onnxruntime/core/platform/ort_mutex.h deleted file mode 100644 index e24665f51423d..0000000000000 --- a/include/onnxruntime/core/platform/ort_mutex.h +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#ifdef _WIN32 -#include -#include -namespace onnxruntime { -// Q: Why OrtMutex is better than std::mutex -// A: OrtMutex supports static initialization but std::mutex doesn't. Static initialization helps us prevent the "static -// initialization order problem". - -// Q: Why std::mutex can't make it? -// A: VC runtime has to support Windows XP at ABI level. But we don't have such requirement. - -// Q: Is OrtMutex faster than std::mutex? -// A: Sure - -class OrtMutex { - private: - SRWLOCK data_ = SRWLOCK_INIT; - - public: - constexpr OrtMutex() = default; - // SRW locks do not need to be explicitly destroyed. 
- ~OrtMutex() = default; - OrtMutex(const OrtMutex&) = delete; - OrtMutex& operator=(const OrtMutex&) = delete; - void lock() { AcquireSRWLockExclusive(native_handle()); } - bool try_lock() noexcept { return TryAcquireSRWLockExclusive(native_handle()) == TRUE; } - void unlock() noexcept { ReleaseSRWLockExclusive(native_handle()); } - using native_handle_type = SRWLOCK*; - - __forceinline native_handle_type native_handle() { return &data_; } -}; - -class OrtCondVar { - CONDITION_VARIABLE native_cv_object = CONDITION_VARIABLE_INIT; - - public: - constexpr OrtCondVar() noexcept = default; - ~OrtCondVar() = default; - - OrtCondVar(const OrtCondVar&) = delete; - OrtCondVar& operator=(const OrtCondVar&) = delete; - - void notify_one() noexcept { WakeConditionVariable(&native_cv_object); } - void notify_all() noexcept { WakeAllConditionVariable(&native_cv_object); } - - void wait(std::unique_lock& lk) { - if (SleepConditionVariableSRW(&native_cv_object, lk.mutex()->native_handle(), INFINITE, 0) != TRUE) { - std::terminate(); - } - } - template - void wait(std::unique_lock& __lk, _Predicate __pred); - - /** - * returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns - * cv_status::no_timeout. - * @param cond_mutex A unique_lock object. - * @param rel_time A chrono::duration object that specifies the amount of time before the thread wakes up. - * @return returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns - * cv_status::no_timeout - */ - template - std::cv_status wait_for(std::unique_lock& cond_mutex, const std::chrono::duration& rel_time); - using native_handle_type = CONDITION_VARIABLE*; - - native_handle_type native_handle() { return &native_cv_object; } - - private: - void timed_wait_impl(std::unique_lock& __lk, - std::chrono::time_point); -}; - -template -void OrtCondVar::wait(std::unique_lock& __lk, _Predicate __pred) { - while (!__pred()) wait(__lk); -} - -template -std::cv_status OrtCondVar::wait_for(std::unique_lock& cond_mutex, - const std::chrono::duration& rel_time) { - // TODO: is it possible to use nsync_from_time_point_ ? - using namespace std::chrono; - if (rel_time <= duration::zero()) - return std::cv_status::timeout; - using SystemTimePointFloat = time_point >; - using SystemTimePoint = time_point; - SystemTimePointFloat max_time = SystemTimePoint::max(); - steady_clock::time_point steady_now = steady_clock::now(); - system_clock::time_point system_now = system_clock::now(); - if (max_time - rel_time > system_now) { - nanoseconds remain = duration_cast(rel_time); - if (remain < rel_time) - ++remain; - timed_wait_impl(cond_mutex, system_now + remain); - } else - timed_wait_impl(cond_mutex, SystemTimePoint::max()); - return steady_clock::now() - steady_now < rel_time ? 
std::cv_status::no_timeout : std::cv_status::timeout; -} -} // namespace onnxruntime -#else -#include "nsync.h" -#include //for unique_lock -#include //for cv_status -namespace onnxruntime { - -class OrtMutex { - nsync::nsync_mu data_ = NSYNC_MU_INIT; - - public: - constexpr OrtMutex() = default; - ~OrtMutex() = default; - OrtMutex(const OrtMutex&) = delete; - OrtMutex& operator=(const OrtMutex&) = delete; - - void lock() { nsync::nsync_mu_lock(&data_); } - bool try_lock() noexcept { return nsync::nsync_mu_trylock(&data_) == 0; } - void unlock() noexcept { nsync::nsync_mu_unlock(&data_); } - - using native_handle_type = nsync::nsync_mu*; - native_handle_type native_handle() { return &data_; } -}; - -class OrtCondVar { - nsync::nsync_cv native_cv_object = NSYNC_CV_INIT; - - public: - constexpr OrtCondVar() noexcept = default; - - ~OrtCondVar() = default; - OrtCondVar(const OrtCondVar&) = delete; - OrtCondVar& operator=(const OrtCondVar&) = delete; - - void notify_one() noexcept { nsync::nsync_cv_signal(&native_cv_object); } - void notify_all() noexcept { nsync::nsync_cv_broadcast(&native_cv_object); } - - void wait(std::unique_lock& lk); - template - void wait(std::unique_lock& __lk, _Predicate __pred); - - /** - * returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns - * cv_status::no_timeout. - * @param cond_mutex A unique_lock object. - * @param rel_time A chrono::duration object that specifies the amount of time before the thread wakes up. - * @return returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns - * cv_status::no_timeout - */ - template - std::cv_status wait_for(std::unique_lock& cond_mutex, const std::chrono::duration& rel_time); - using native_handle_type = nsync::nsync_cv*; - native_handle_type native_handle() { return &native_cv_object; } - - private: - void timed_wait_impl(std::unique_lock& __lk, - std::chrono::time_point); -}; - -template -void OrtCondVar::wait(std::unique_lock& __lk, _Predicate __pred) { - while (!__pred()) wait(__lk); -} - -template -std::cv_status OrtCondVar::wait_for(std::unique_lock& cond_mutex, - const std::chrono::duration& rel_time) { - // TODO: is it possible to use nsync_from_time_point_ ? - using namespace std::chrono; - if (rel_time <= duration::zero()) - return std::cv_status::timeout; - using SystemTimePointFloat = time_point >; - using SystemTimePoint = time_point; - SystemTimePointFloat max_time = SystemTimePoint::max(); - steady_clock::time_point steady_now = steady_clock::now(); - system_clock::time_point system_now = system_clock::now(); - if (max_time - rel_time > system_now) { - nanoseconds remain = duration_cast(rel_time); - if (remain < rel_time) - ++remain; - timed_wait_impl(cond_mutex, system_now + remain); - } else - timed_wait_impl(cond_mutex, SystemTimePoint::max()); - return steady_clock::now() - steady_now < rel_time ? std::cv_status::no_timeout : std::cv_status::timeout; -} -}; // namespace onnxruntime -#endif diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 55abb90b981f5..d035fd34bd072 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -31,11 +31,37 @@ enum COREMLFlags { // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. 
COREML_FLAG_CREATE_MLPROGRAM = 0x010, + // https://developer.apple.com/documentation/coreml/mlcomputeunits?language=objc + // there are four compute units: + // MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll + // Different compute units have different performance and power consumption. + COREML_FLAG_USE_CPU_AND_GPU = 0x020, // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it - COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM, + COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU, }; +// MLComputeUnits can be one of the following values: +// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll' +// These values are intended to be used with Ort::SessionOptions::AppendExecutionProvider (C++ API) +// and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead. +static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits"; +static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat"; +// same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES +static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes"; +static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs"; +// provided by https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property +// Core ML segments the model’s compute graph and specializes each segment for the target compute device. +// This process can affect the model loading time and the prediction latency. +// Use this option to tailor the specialization strategy for your model. +static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy"; +// Profile the Core ML MLComputePlan. +// This logs the hardware each operator is dispatched to and the estimated execution time. +// Intended for developer usage, but it provides useful diagnostic information if performance is not as expected.
+static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan"; +// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu +static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU"; + #ifdef __cplusplus extern "C" { #endif diff --git a/include/onnxruntime/core/providers/rocm/rocm_context.h b/include/onnxruntime/core/providers/rocm/rocm_context.h index f187e0cbb3a89..aad1736217129 100644 --- a/include/onnxruntime/core/providers/rocm/rocm_context.h +++ b/include/onnxruntime/core/providers/rocm/rocm_context.h @@ -7,7 +7,7 @@ #include "core/providers/custom_op_context.h" #include #include -#include +#include namespace Ort { @@ -16,7 +16,7 @@ namespace Custom { struct RocmContext : public CustomOpContext { hipStream_t hip_stream = {}; miopenHandle_t miopen_handle = {}; - rocblas_handle rblas_handle = {}; + hipblasHandle_t blas_handle = {}; void Init(const OrtKernelContext& kernel_ctx) { const auto& ort_api = Ort::GetApi(); @@ -40,11 +40,11 @@ struct RocmContext : public CustomOpContext { resource = {}; status = ort_api.KernelContext_GetResource( - &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::rocblas_handle_t, &resource); + &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::hipblas_handle_t, &resource); if (status) { - ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + ORT_CXX_API_THROW("failed to fetch hipblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); } - rblas_handle = reinterpret_cast(resource); + blas_handle = reinterpret_cast(resource); } }; diff --git a/include/onnxruntime/core/providers/rocm/rocm_resource.h b/include/onnxruntime/core/providers/rocm/rocm_resource.h index 772447a1809d8..db032b48714c3 100644 --- a/include/onnxruntime/core/providers/rocm/rocm_resource.h +++ b/include/onnxruntime/core/providers/rocm/rocm_resource.h @@ -8,5 +8,9 @@ enum RocmResource : int { hip_stream_t = rocm_resource_offset, miopen_handle_t, - rocblas_handle_t + hipblas_handle_t, + deferred_cpu_allocator_t, + // below are rocm ep options + device_id_t, // 10004 + arena_extend_strategy_t }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0348123ab7acb..a35d975ac8f1b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -38,7 +38,7 @@ * * This value is used by some API functions to behave as this version of the header expects. */ -#define ORT_API_VERSION 20 +#define ORT_API_VERSION 21 #ifdef __cplusplus extern "C" { @@ -626,8 +626,13 @@ typedef struct OrtMIGraphXProviderOptions { } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + * \brief This struct has been frozen since ORT 1.13.0. It is maintained as part of the legacy API for compatibility. + * \brief For the latest OpenVINO provider options, use the ProviderOptions map instead. + * \brief The latest OpenVINO provider options are listed in the + * \htmlonly + *
onnxruntime document. + * \endhtmlonly + * \see OrtApi::SessionOptionsAppendExecutionProvider() */ typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus @@ -3651,13 +3656,20 @@ struct OrtApi { * - "73" * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Used for float32 model for HTP backend. - Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": With fp32 precision. - - "1": Default. With fp16 precision. - "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. - - "0": Default. Disabled. - - "1": Enabled. + * "enable_htp_fp16_precision": Used for float32 model for HTP backend. + * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + * - "0": With fp32 precision. + * - "1": Default. With fp16 precision. + * "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. + * - "0": Default. Disabled. + * - "1": Enabled. + * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another + * execution provider (typically CPU EP). + * - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O. + * - "1": Enabled. + * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary. + * - "0": Default. Disabled. + * - "1": Enabled. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", @@ -3783,7 +3795,7 @@ struct OrtApi { /** \brief Release an OrtCANNProviderOptions * - * \param[in] the pointer of OrtCANNProviderOptions which will been deleted + * \param[in] input The pointer of OrtCANNProviderOptions which will been deleted * * \since Version 1.13. */ @@ -4603,6 +4615,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_ OrtSessionOptions* options, @@ -4620,6 +4634,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, @@ -4633,7 +4649,10 @@ struct OrtApi { * \param[in] mem_info OrtMemoryInfo instance * \param[in] count_or_bytes How many bytes is this scratch buffer * \param[out] out A pointer to the scrach buffer + * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); @@ -4644,6 +4663,8 @@ struct OrtApi { * \param[out] out A pointer to OrtAllocator * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); @@ -4665,6 +4686,8 @@ struct OrtApi { * \param[in] num_external_initializer_files Number of external files * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. 
*/ ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, @@ -4685,6 +4708,10 @@ struct OrtApi { * The data would still be copied to device if required by the model at inference time. * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator, _Outptr_ OrtLoraAdapter** out); @@ -4701,6 +4728,10 @@ struct OrtApi { * The data would still be copied to device if required by the model at inference time. * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator, _Outptr_ OrtLoraAdapter** out); @@ -4720,8 +4751,33 @@ struct OrtApi { * * \param[in] options OrtRunOptions instance * \param[in] adapter OrtLoraAdapter instance + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); + + /// @} + /// \name OrtEpDynamicOptions + /// @{ + + /** \brief Set DynamicOptions for EPs (Execution Providers) + * + * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` + * Look for `kOrtEpDynamicOptions` + * + * \param[in] sess OrtSession + * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys + * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values + * \param[in] kv_len Number of elements in the keys and values arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, + _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index b4911d515d071..f3e9758766d00 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -650,6 +650,9 @@ using AllocatedStringPtr = std::unique_ptr; * constructors to construct an instance of a Status object from exceptions. */ struct Status : detail::Base { + using Base = detail::Base; + using Base::Base; + explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. 
explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception @@ -728,6 +731,9 @@ struct Env : detail::Base { * */ struct CustomOpDomain : detail::Base { + using Base = detail::Base; + using Base::Base; + explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateCustomOpDomain @@ -963,8 +969,10 @@ struct SessionOptions : detail::SessionOptionsImpl { * */ struct ModelMetadata : detail::Base { - explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used - explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API + using Base = detail::Base; + using Base::Base; + + explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used /** \brief Returns a copy of the producer name. * @@ -1140,6 +1148,19 @@ struct SessionImpl : ConstSessionImpl { * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling + + /** \brief Set DynamicOptions for EPs (Execution Providers) + * + * Wraps OrtApi::SetEpDynamicOptions + * + * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` + * Look for `kOrtEpDynamicOptions` + * + * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys + * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values + * \param[in] kv_len Number of elements in the keys and values arrays + */ + void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); }; } // namespace detail @@ -1224,6 +1245,9 @@ using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl { + using Base = detail::TensorTypeAndShapeInfoImpl; + using Base::Base; + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } @@ -1245,6 +1269,9 @@ using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl { + using Base = detail::SequenceTypeInfoImpl; + using Base::Base; + explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl{p} {} ///< Used for interop with the C API ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } @@ -1280,6 +1307,9 @@ using ConstMapTypeInfo = detail::MapTypeInfoImpl { + using Base = detail::MapTypeInfoImpl; + using Base::Base; + explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl{p} {} ///< Used for interop with the C API ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } @@ -1311,6 +1341,9 @@ using ConstTypeInfo = detail::TypeInfoImpl>; /// the information about contained sequence or map depending on the ONNXType. ///
struct TypeInfo : detail::TypeInfoImpl { + using Base = detail::TypeInfoImpl; + using Base::Base; + explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop @@ -1648,11 +1681,11 @@ using UnownedValue = detail::ValueImpl>; */ struct Value : detail::ValueImpl { using Base = detail::ValueImpl; + using Base::Base; using OrtSparseValuesParam = detail::OrtSparseValuesParam; using Shape = detail::Shape; - explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used - explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API + explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used Value(Value&&) = default; Value& operator=(Value&&) = default; @@ -1928,6 +1961,10 @@ struct ArenaCfg : detail::Base { /// This struct provides life time management for custom op attribute ///
struct OpAttr : detail::Base { + using Base = detail::Base; + using Base::Base; + + explicit OpAttr(std::nullptr_t) {} OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); }; @@ -2170,6 +2207,8 @@ using ConstKernelInfo = detail::KernelInfoImpl struct KernelInfo : detail::KernelInfoImpl { + using Base = detail::KernelInfoImpl; + using Base::Base; explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } @@ -2179,6 +2218,9 @@ struct KernelInfo : detail::KernelInfoImpl { /// Create and own custom defined operation. ///
struct Op : detail::Base { + using Base = detail::Base; + using Base::Base; + explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used explicit Op(OrtOp*); ///< Take ownership of the OrtOp diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 7401cb2438121..3aeb9412f350e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -51,7 +51,7 @@ inline void ThrowOnError(const Status& st) { } } -inline Status::Status(OrtStatus* status) noexcept : Base{status} { +inline Status::Status(OrtStatus* status) noexcept : detail::Base{status} { } inline Status::Status(const std::exception& e) noexcept { @@ -1093,6 +1093,11 @@ inline AllocatedStringPtr SessionImpl::EndProfilingAllocated(OrtAllocator* al return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); } +template +inline void SessionImpl::SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len) { + ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len)); +} + } // namespace detail inline SessionOptions::SessionOptions() { @@ -1903,7 +1908,7 @@ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std:: inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} -inline Op::Op(OrtOp* p) : Base(p) {} +inline Op::Op(OrtOp* p) : detail::Base(p) {} inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, const char** type_constraint_names, diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 9942f8c656760..c80b8c0c164b6 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -49,8 +49,3 @@ static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_con // If the value is set to -1, cuda graph capture/replay is disabled in that run. // User are not expected to set the value to 0 as it is reserved for internal use. static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; - -// Specify the type of workload for this run. -// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] -// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. -static const char* const kOrtRunOptionsWorkloadType = "run.workload_type"; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index b0539b78a69d1..8f1bc98ce7b49 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -261,8 +261,8 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; // Flag to specify whether to dump the EP context into the Onnx model. -// "0": dump the EP context into separate file, keep the file name in the Onnx model. -// "1": dump the EP context into the Onnx model. (default). 
+// "0": dump the EP context into separate file, keep the file name in the Onnx model. (default). +// "1": dump the EP context into the Onnx model. static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; // Specify the EPContext node name prefix to make it unique @@ -283,7 +283,9 @@ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas // If not provided, default is 4. static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; +// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME +// Meant to be used with SetEpDynamicOptions // Specify the type of workload for this session. // “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] // “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. -static const char* const kOrtSessionOptionsWorkloadType = "session.workload_type"; +static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; diff --git a/java/build-android.gradle b/java/build-android.gradle index fd22fa27e8db9..9c4275b74f626 100644 --- a/java/build-android.gradle +++ b/java/build-android.gradle @@ -8,25 +8,47 @@ def publishDir = System.properties['publishDir'] def minSdkVer = System.properties['minSdkVer'] def targetSdkVer = System.properties['targetSdkVer'] boolean enableTrainingApis = (System.properties['ENABLE_TRAINING_APIS'] ?: "0") == "1" - -// Since Android requires a higher numbers indicating more recent versions -// This function assume ORT version number will be in formart of A.B.C such as 1.7.0 -// We generate version code A[0{0,1}]B[0{0,1}]C, -// for example '1.7.0' -> 10700, '1.6.15' -> 10615 -def getVersionCode(String version){ - String[] codes = version.split('\\.'); +def releaseVersionSuffix = System.properties['releaseVersionSuffix'] ?: "" +// Expected format for qnnVersion: major.minor.patch (e.g., 2.26.0) +// QNN package version does not follow Semantic Versioning (SemVer) format. 
+// For non-QNN builds, qnnVersion will be null +def qnnVersion = System.properties['qnnVersion'] + +// Since Android requires higher numbers indicating more recent versions +// This function assumes ORT version number will be in the format of A.B.C[-rc/beta/alpha.D] such as 1.20.0 or 1.20.0-rc.1 +// We generate version code A[0{0,1}]B[0{0,1}]C[0{0,1}]{1,2,3,4}D[01-99] +// for example '1.20.0' -> 12000400, '1.20.0-rc.1 ' -> 12000301 +// '1.20.0-beta.1' -> 12000201, '1.20.0-alpha.1' -> 12000101 +def getVersionCode(String version) { + String[] versionAndRelSufx = version.split('-') + String[] codes = versionAndRelSufx[0].split('\\.') // This will have problem if we have 3 digit [sub]version number, such as 1.7.199 // but it is highly unlikely to happen - String versionCodeStr = String.format("%d%02d%02d", codes[0] as int, codes[1] as int, codes[2] as int); - return versionCodeStr as int; + String versionCodeStr = String.format("%d%02d%02d", codes[0] as int, codes[1] as int, codes[2] as int) + + if (versionAndRelSufx.length > 1) { + String suffixType = versionAndRelSufx[1].split('\\.')[0] + String suffixNumber = versionAndRelSufx[1].split('\\.')[1] + def suffixMap = ['alpha': '1', 'beta': '2', 'rc': '3'] + versionCodeStr += suffixMap[suffixType] + String.format("%02d", suffixNumber as int) + } else { + versionCodeStr += "400" // For a normal release version without suffix, get the highest version code + } + println "Version code for $version is $versionCodeStr" + return versionCodeStr as int } project.buildDir = buildDir -project.version = rootProject.file('../VERSION_NUMBER').text.trim() +def project_version = rootProject.file('../VERSION_NUMBER').text.trim() +project.version = releaseVersionSuffix ? "${project_version}${releaseVersionSuffix}" : project_version project.group = "com.microsoft.onnxruntime" def tmpArtifactId = enableTrainingApis ? project.name + "-training" : project.name -def mavenArtifactId = tmpArtifactId + '-android' +def mavenArtifactId = tmpArtifactId + '-android' + (qnnVersion != null ? '-qnn' : '') + +// Should the mavenArtifactId be read from the packageName variable, as +// that's how it's used in the build_aar_copy_artifacts.sh while copying the artifacts? + def defaultDescription = 'ONNX Runtime is a performance-focused inference engine for ONNX (Open Neural Network ' + 'Exchange) models. This package contains the Android (aar) build of ONNX Runtime. It includes support for all ' + 'types and operators, for ONNX format models. All standard ONNX models can be executed with this package.' @@ -34,6 +56,10 @@ def trainingDescription = 'The onnxruntime-training android package is designed 'wide range of ONNX models on edge devices, such as mobile phones, tablets, and other portable devices with ' + 'a focus on minimizing resource usage and maximizing accuracy.' + 'See https://github.com/microsoft/onnxruntime-training-examples/tree/master/on_device_training for more details.' +def qnnDescription = 'ONNX Runtime is a performance-focused inference engine for ONNX (Open Neural Network ' + + 'Exchange) models. This package contains the Android (aar) build of ONNX Runtime with the QNN Execution Provider. ' + + 'It includes support for all types and operators, for ONNX format models. All standard ONNX models can be executed ' + + 'with this package.'
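To make the version-code scheme in `getVersionCode` above easier to follow, here is a small self-checking sketch of the same mapping in Kotlin. It is illustrative only and not used by the build; the function name `versionCodeFor` is made up for this example, and the checks mirror the worked examples in the comment above.

```kotlin
// Illustrative re-implementation of the version-code mapping documented above:
// A.B.C                    -> A BB CC 4 00   (full release, highest suffix code)
// A.B.C-{alpha|beta|rc}.D  -> A BB CC {1|2|3} DD
fun versionCodeFor(version: String): Int {
    val parts = version.trim().split("-", limit = 2)
    val (major, minor, patch) = parts[0].split(".").map { it.toInt() }
    var code = "%d%02d%02d".format(major, minor, patch)
    code += if (parts.size == 1) {
        "400" // plain release outranks any pre-release of the same version
    } else {
        val (suffixType, suffixNumber) = parts[1].split(".")
        val typeCode = mapOf("alpha" to "1", "beta" to "2", "rc" to "3").getValue(suffixType)
        typeCode + "%02d".format(suffixNumber.toInt())
    }
    return code.toInt()
}

fun main() {
    check(versionCodeFor("1.20.0") == 12000400)
    check(versionCodeFor("1.20.0-rc.1") == 12000301)
    check(versionCodeFor("1.20.0-beta.1") == 12000201)
    check(versionCodeFor("1.20.0-alpha.1") == 12000101)
}
```

The "400" branch is what guarantees that a final release always receives a higher version code than any of its pre-releases.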
buildscript { repositories { @@ -56,7 +82,7 @@ allprojects { } android { - compileSdkVersion 32 + compileSdkVersion 34 defaultConfig { minSdkVersion minSdkVer @@ -82,8 +108,8 @@ android { } compileOptions { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 } sourceSets { @@ -137,8 +163,9 @@ publishing { artifact sourcesJar pom { - name = enableTrainingApis ? 'onnxruntime-training' : 'onnx-runtime' - description = enableTrainingApis ? trainingDescription : defaultDescription + name = qnnVersion != null ? 'onnxruntime-qnn' : (enableTrainingApis ? 'onnxruntime-training' : 'onnx-runtime') + description = qnnVersion != null ? qnnDescription : (enableTrainingApis ? trainingDescription : defaultDescription) + url = 'https://microsoft.github.io/onnxruntime/' licenses { license { @@ -162,6 +189,16 @@ publishing { email = 'onnxruntime@microsoft.com' } } + + if (qnnVersion != null) { + println "Modifying the POM XML to include QNN dependency" + withXml { + def dependencynode = asNode().appendNode('dependencies').appendNode('dependency') + dependencynode.appendNode('groupId', 'com.qualcomm.qti') + dependencynode.appendNode('artifactId', 'qnn-runtime') + dependencynode.appendNode('version', qnnVersion) + } + } } } } diff --git a/java/build.gradle b/java/build.gradle index 34ac93cce6f4e..845121dd17a48 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -50,8 +50,8 @@ mavenSettings { } java { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 } // This jar tasks serves as a CMAKE signaling diff --git a/java/gradle/wrapper/gradle-wrapper.properties b/java/gradle/wrapper/gradle-wrapper.properties index 4baf5a11d45a3..381baa9cef1ec 100644 --- a/java/gradle/wrapper/gradle-wrapper.properties +++ b/java/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=9631d53cf3e74bfa726893aee1f8994fee4e060c401335946dba2156f440f24c -distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip +distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d +distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/java/gradlew.bat b/java/gradlew.bat index 93e3f59f135dd..25da30dbdeee9 100644 --- a/java/gradlew.bat +++ b/java/gradlew.bat @@ -43,11 +43,11 @@ set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -57,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 
1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 7280f3c88e2e8..32dc9d9f84aaa 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1323,6 +1323,18 @@ public void addQnn(Map providerOptions) throws OrtException { addExecutionProvider(qnnProviderName, providerOptions); } + /** + * Adds CoreML as an execution backend. + * + * @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML + * execution provider's documentation. + * @throws OrtException If there was an error in native code. + */ + public void addCoreML(Map providerOptions) throws OrtException { + String CoreMLProviderName = "CoreML"; + addExecutionProvider(CoreMLProviderName, providerOptions); + } + private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) throws OrtException; diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index cec3fadf446ca..22bf940844774 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -25,7 +25,9 @@ public enum CoreMLFlags implements OrtFlags { * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or * later. */ - CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) + CREATE_MLPROGRAM(16), // COREML_FLAG_CREATE_MLPROGRAM(0x010) + /** exclude ANE */ + CPU_AND_GPU(32); // COREML_FLAG_USE_CPU_AND_GPU(0x020) /** The native value of the enum. */ public final int value; diff --git a/java/src/test/android/README.md b/java/src/test/android/README.md index b84021669c9fe..b086be3dc904c 100644 --- a/java/src/test/android/README.md +++ b/java/src/test/android/README.md @@ -29,6 +29,11 @@ Use the android's [build instructions](https://onnxruntime.ai/docs/build/android Please note that you may need to set the `--android_abi=x86_64` (the default option is `arm64-v8a`). This is because android instrumentation test is run on an android emulator which requires an abi of `x86_64`. +#### QNN Builds +We use two AndroidManifest.xml files to manage different runtime requirements for QNN support. In the [build configuration](app/build.gradle), we specify which manifest file to use based on the qnnVersion. +In the [QNN manifest](app/src/main/AndroidManifestQnn.xml), we include the declaration for libcdsprpc.so, which is required for devices using QNN and Qualcomm DSP capabilities. +For QNN builds, it is also necessary to set the `ADSP_LIBRARY_PATH` environment variable to the [native library directory](https://developer.android.com/reference/android/content/pm/ApplicationInfo#nativeLibraryDir) depending on the device. This ensures that any native libraries downloaded as dependencies such as QNN libraries are found by the application. This is conditionally added by using the BuildConfig field IS_QNN_BUILD set in the build.gradle file. 
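As a rough sketch of the startup step described above (the activity name here is hypothetical, and `BuildConfig.IS_QNN_BUILD` is the generated flag set from `build.gradle`), the environment variable can be exported before any QNN library is loaded:

```kotlin
import android.os.Bundle
import android.system.Os
import androidx.appcompat.app.AppCompatActivity

// Hypothetical activity illustrating the ADSP_LIBRARY_PATH setup for QNN builds.
class QnnSetupActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        if (BuildConfig.IS_QNN_BUILD) { // generated from the qnnVersion branch in build.gradle
            // Point the DSP loader at the app's native library directory so the QNN
            // libraries pulled in as Maven dependencies can be resolved at runtime.
            val nativeLibDir = applicationContext.applicationInfo.nativeLibraryDir
            Os.setenv("ADSP_LIBRARY_PATH", nativeLibDir, true) // true = overwrite if already set
        }
        super.onCreate(savedInstanceState)
    }
}
```

Setting the variable this early keeps it in place before ONNX Runtime, and therefore the QNN backend, is initialized by the test code.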
+ #### Build Output The build will generate two apks which is required to run the test application in `$YOUR_BUILD_DIR/java/androidtest/android/app/build/outputs/apk`: diff --git a/java/src/test/android/app/build.gradle b/java/src/test/android/app/build.gradle index 381de06cc09de..baf18e714d25c 100644 --- a/java/src/test/android/app/build.gradle +++ b/java/src/test/android/app/build.gradle @@ -4,18 +4,27 @@ plugins { } def minSdkVer = System.properties.get("minSdkVer")?:24 +def qnnVersion = System.properties['qnnVersion'] android { - compileSdkVersion 32 + compileSdkVersion 34 defaultConfig { applicationId "ai.onnxruntime.example.javavalidator" minSdkVersion minSdkVer - targetSdkVersion 32 + targetSdkVersion 34 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + // Add BuildConfig field for qnnVersion + if (qnnVersion != null) { + buildConfigField "boolean", "IS_QNN_BUILD", "true" + } + else { + buildConfigField "boolean", "IS_QNN_BUILD", "false" + } } buildTypes { @@ -25,11 +34,29 @@ android { } } compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 + sourceCompatibility JavaVersion.VERSION_17 + targetCompatibility JavaVersion.VERSION_17 } kotlinOptions { - jvmTarget = '1.8' + jvmTarget = '17' + } + // Conditional packagingOptions for QNN builds only + if (qnnVersion != null) { + packagingOptions { + jniLibs { + useLegacyPackaging = true + } + // Dsp is used in older QC devices and not supported by ORT + // Gpu support isn't the target, we just want Npu support (Htp) + exclude 'lib/arm64-v8a/libQnnGpu.so' + exclude 'lib/arm64-v8a/libQnnDsp*.so' + } + + sourceSets { + main { + manifest.srcFile 'src/main/AndroidManifestQnn.xml' // Use QNN manifest + } + } } namespace 'ai.onnxruntime.example.javavalidator' } @@ -42,11 +69,20 @@ dependencies { implementation 'com.google.android.material:material:1.3.0' implementation 'androidx.constraintlayout:constraintlayout:2.0.4' testImplementation 'junit:junit:4.+' - androidTestImplementation 'androidx.test.ext:junit:1.1.3' - androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' - implementation(name: "onnxruntime-android", ext: "aar") + androidTestImplementation "androidx.test.ext:junit:1.1.5" + androidTestImplementation "androidx.test.espresso:espresso-core:3.5.0" - androidTestImplementation 'androidx.test:runner:1.4.0' - androidTestImplementation 'androidx.test:rules:1.4.0' + androidTestImplementation "androidx.test:runner:1.5.2" + androidTestImplementation "androidx.test:rules:1.5.0" androidTestImplementation 'com.microsoft.appcenter:espresso-test-extension:1.4' + + // dependencies for onnxruntime-android-qnn + if (qnnVersion != null) { + implementation(name: "onnxruntime-android-qnn", ext: "aar") + implementation "com.qualcomm.qti:qnn-runtime:$qnnVersion" + } + else { + implementation(name: "onnxruntime-android", ext: "aar") + } + } diff --git a/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt b/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt index 166803ae263a5..5e6bee6cac9f4 100644 --- a/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt +++ b/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt @@ -38,13 +38,18 @@ class SimpleTest { @Test fun runSigmoidModelTest() { for (intraOpNumThreads in 1..4) { - 
runSigmoidModelTestImpl(intraOpNumThreads) + runSigmoidModelTestImpl(intraOpNumThreads, OrtProvider.CPU) } } @Test fun runSigmoidModelTestNNAPI() { - runSigmoidModelTestImpl(1, true) + runSigmoidModelTestImpl(1, OrtProvider.NNAPI) + } + + @Test + fun runSigmoidModelTestQNN() { + runSigmoidModelTestImpl(1, OrtProvider.QNN) } @Throws(IOException::class) @@ -54,22 +59,49 @@ class SimpleTest { } @Throws(OrtException::class, IOException::class) - fun runSigmoidModelTestImpl(intraOpNumThreads: Int, useNNAPI: Boolean = false) { - reportHelper.label("Start Running Test with intraOpNumThreads=$intraOpNumThreads, useNNAPI=$useNNAPI") + fun runSigmoidModelTestImpl(intraOpNumThreads: Int, executionProvider: OrtProvider) { + reportHelper.label("Start Running Test with intraOpNumThreads=$intraOpNumThreads, executionProvider=$executionProvider") Log.println(Log.INFO, TAG, "Testing with intraOpNumThreads=$intraOpNumThreads") - Log.println(Log.INFO, TAG, "Testing with useNNAPI=$useNNAPI") + Log.println(Log.INFO, TAG, "Testing with executionProvider=$executionProvider") + val env = OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE) env.use { val opts = SessionOptions() opts.setIntraOpNumThreads(intraOpNumThreads) - if (useNNAPI) { - if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) { - opts.addNnapi() - } else { - Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test") - return + + when (executionProvider) { + + OrtProvider.NNAPI -> { + if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) { + opts.addNnapi() + } else { + Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test") + return + } + } + + OrtProvider.QNN -> { + if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.QNN)) { + // Since this is running in an Android environment, we use the .so library + val qnnLibrary = "libQnnHtp.so" + val providerOptions = Collections.singletonMap("backend_path", qnnLibrary) + opts.addQnn(providerOptions) + } else { + Log.println(Log.INFO, TAG, "NO QNN EP available, skip the test") + return + } + } + + OrtProvider.CPU -> { + // No additional configuration is needed for CPU + } + + else -> { + // Non exhaustive when statements on enum will be prohibited in future Gradle versions + Log.println(Log.INFO, TAG, "Skipping test as OrtProvider is not implemented") } } + opts.use { val session = env.createSession(readModel("sigmoid.ort"), opts) session.use { @@ -92,13 +124,15 @@ class SimpleTest { output.use { @Suppress("UNCHECKED_CAST") val rawOutput = output[0].value as Array> + // QNN EP will run the Sigmoid float32 op with fp16 precision + val precision = if (executionProvider == OrtProvider.QNN) 1e-3 else 1e-6 for (i in 0..2) { for (j in 0..3) { for (k in 0..4) { Assert.assertEquals( rawOutput[i][j][k], expected[i][j][k], - 1e-6.toFloat() + precision.toFloat() ) } } diff --git a/java/src/test/android/app/src/main/AndroidManifest.xml b/java/src/test/android/app/src/main/AndroidManifest.xml index 2938b7e8bf409..08a612ed79fd6 100644 --- a/java/src/test/android/app/src/main/AndroidManifest.xml +++ b/java/src/test/android/app/src/main/AndroidManifest.xml @@ -17,4 +17,4 @@ - \ No newline at end of file + diff --git a/java/src/test/android/app/src/main/AndroidManifestQnn.xml b/java/src/test/android/app/src/main/AndroidManifestQnn.xml new file mode 100644 index 0000000000000..c9416523a9c91 --- /dev/null +++ b/java/src/test/android/app/src/main/AndroidManifestQnn.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + diff --git 
a/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt b/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt index 62e23c4b9b862..3b3a2d057b16e 100644 --- a/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt +++ b/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt @@ -1,11 +1,19 @@ package ai.onnxruntime.example.javavalidator import android.os.Bundle +import android.system.Os import androidx.appcompat.app.AppCompatActivity /*Empty activity app mainly used for testing*/ class MainActivity : AppCompatActivity() { override fun onCreate(savedInstanceState: Bundle?) { + if (BuildConfig.IS_QNN_BUILD) { + val adspLibraryPath = applicationContext.applicationInfo.nativeLibraryDir + // set the path variable to the native library directory + // so that any native libraries downloaded as dependencies + // (like qnn libs) are found + Os.setenv("ADSP_LIBRARY_PATH", adspLibraryPath, true) + } super.onCreate(savedInstanceState) } -} \ No newline at end of file +} diff --git a/js/.eslintrc.js b/js/.eslintrc.js index bd1e9061355f5..462e417df1d66 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -198,19 +198,6 @@ module.exports = { '_OrtReleaseTensor', '_OrtRun', '_OrtRunWithBinding', - '_OrtTrainingCopyParametersFromBuffer', - '_OrtTrainingCopyParametersToBuffer', - '_OrtTrainingCreateSession', - '_OrtTrainingEvalStep', - '_OrtTrainingGetModelInputOutputCount', - '_OrtTrainingGetModelInputOutputName', - '_OrtTrainingGetParametersSize', - '_OrtTrainingLazyResetGrad', - '_OrtTrainingLoadCheckpoint', - '_OrtTrainingOptimizerStep', - '_OrtTrainingReleaseCheckpoint', - '_OrtTrainingReleaseSession', - '_OrtTrainingRunTrainStep', ], }, ], diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index e27e67622aa82..e63f9c6c9147f 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -3,7 +3,6 @@ import { InferenceSession } from './inference-session.js'; import { OnnxValue } from './onnx-value.js'; -import { TrainingSession } from './training-session.js'; /** * @ignore @@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler { ): Promise; } -/** - * Represent a handler instance of a training inference session. - * - * @ignore - */ -export interface TrainingSessionHandler extends SessionHandler { - readonly evalInputNames: readonly string[]; - readonly evalOutputNames: readonly string[]; - - lazyResetGrad(): Promise; - runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise; - runOptimizerStep(options: InferenceSession.RunOptions): Promise; - runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise; - - getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; -} - /** * Represent a backend that provides implementation of model inferencing. 
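Note: with the `TrainingSessionHandler` interface removed above, a JavaScript backend only has to provide session initialization and an inference handler. The sketch below is a minimal, non-functional stub, assuming the `registerBackend(name, backend, priority)` API exported by onnxruntime-common; the `'stub'` name and priority value are illustrative.

```typescript
// Minimal sketch of a custom Backend after the training handler removal: only init() and
// createInferenceSessionHandler() remain to implement. The handler body is a stub.
import { registerBackend } from 'onnxruntime-common';
import type { Backend, InferenceSession, InferenceSessionHandler } from 'onnxruntime-common';

const stubBackend: Backend = {
  async init(): Promise<void> {
    // load native/wasm resources here
  },
  async createInferenceSessionHandler(
    _uriOrBuffer: string | Uint8Array,
    _options?: InferenceSession.SessionOptions,
  ): Promise<InferenceSessionHandler> {
    throw new Error('stub backend: session creation not implemented');
  },
};

// Backends are looked up by execution provider name, so this one is only used when a session
// is created with executionProviders: ['stub'].
registerBackend('stub', stubBackend, 1);
```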
* @@ -84,14 +56,6 @@ export interface Backend { uriOrBuffer: string | Uint8Array, options?: InferenceSession.SessionOptions, ): Promise; - - createTrainingSessionHandler?( - checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, - trainModelUriOrBuffer: TrainingSession.UriOrBuffer, - evalModelUriOrBuffer: TrainingSession.UriOrBuffer, - optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, - options: InferenceSession.SessionOptions, - ): Promise; } export { registerBackend } from './backend-impl.js'; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 642a897a90d26..d6d9f7fa48790 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import { env as envImpl } from './env-impl.js'; +import { TryGetGlobalType } from './type-helper.js'; export declare namespace Env { export type WasmPathPrefix = string; @@ -14,7 +15,6 @@ export declare namespace Env { * If not modified, the filename of the .wasm file is: * - `ort-wasm-simd-threaded.wasm` for default build * - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN) - * - `ort-training-wasm-simd-threaded.wasm` for training build */ wasm?: URL | string; /** @@ -25,7 +25,6 @@ export declare namespace Env { * If not modified, the filename of the .mjs file is: * - `ort-wasm-simd-threaded.mjs` for default build * - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN) - * - `ort-training-wasm-simd-threaded.mjs` for training build */ mjs?: URL | string; } @@ -46,17 +45,19 @@ export declare namespace Env { * * This setting is available only when WebAssembly SIMD feature is available in current context. * + * @defaultValue `true` + * * @deprecated This property is deprecated. Since SIMD is supported by all major JavaScript engines, non-SIMD * build is no longer provided. This property will be removed in future release. - * @defaultValue `true` */ simd?: boolean; /** * set or get a boolean value indicating whether to enable trace. * - * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. * @defaultValue `false` + * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. */ trace?: boolean; @@ -154,7 +155,7 @@ export declare namespace Env { /** * Set or get the profiling configuration. */ - profiling?: { + profiling: { /** * Set or get the profiling mode. * @@ -177,6 +178,9 @@ export declare namespace Env { * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. * * @defaultValue `undefined` + * + * @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if + * you want to use a specific power preference. */ powerPreference?: 'low-power' | 'high-performance'; /** @@ -188,6 +192,9 @@ export declare namespace Env { * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. * * @defaultValue `undefined` + * + * @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if + * you want to use a specific fallback option. */ forceFallbackAdapter?: boolean; /** @@ -200,22 +207,25 @@ export declare namespace Env { * value will be the GPU adapter that created by the underlying WebGPU backend. * * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". 
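Note: the `env.ts` change above drops the training artifact names from the documented `.wasm`/`.mjs` overrides, leaving only the default and JSEP filenames. A short sketch of how these overrides are typically set from application code (the CDN URL is illustrative):

```typescript
// Overriding where the WebAssembly artifacts are loaded from. Either a prefix string or the
// per-file WasmFilePaths object documented above can be assigned.
import * as ort from 'onnxruntime-web';

// Option 1: a prefix; the runtime appends the default filenames.
ort.env.wasm.wasmPaths = 'https://example.com/ort-dist/';

// Option 2: explicit per-file overrides (URL or string), matching the WasmFilePaths shape.
ort.env.wasm.wasmPaths = {
  wasm: new URL('https://example.com/ort-dist/ort-wasm-simd-threaded.jsep.wasm'),
  mjs: new URL('https://example.com/ort-dist/ort-wasm-simd-threaded.jsep.mjs'),
};
```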
- * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. * - * see comments on {@link Tensor.GpuBufferType} + * @deprecated It is no longer recommended to use this property. The latest WebGPU spec adds `GPUDevice.adapterInfo` + * (https://www.w3.org/TR/webgpu/#dom-gpudevice-adapterinfo), which allows to get the adapter information from the + * device. When it's available, there is no need to set/get the {@link adapter} property. */ - adapter: unknown; + adapter: TryGetGlobalType<'GPUAdapter'>; /** - * Get the device for WebGPU. - * - * This property is only available after the first WebGPU inference session is created. - * - * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types". - * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type. + * Set or get the GPU device for WebGPU. * - * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types". + * There are 3 valid scenarios of accessing this property: + * - Set a value before the first WebGPU inference session is created. The value will be used by the WebGPU backend + * to perform calculations. If the value is not a `GPUDevice` object, an error will be thrown. + * - Get the value before the first WebGPU inference session is created. This will try to create a new GPUDevice + * instance. Returns a `Promise` that resolves to a `GPUDevice` object. + * - Get the value after the first WebGPU inference session is created. Returns a resolved `Promise` to the + * `GPUDevice` object used by the WebGPU backend. */ - readonly device: unknown; + get device(): Promise>; + set device(value: TryGetGlobalType<'GPUDevice'>); /** * Set or get whether validate input content. 
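Note: the `device` accessor documented above replaces the deprecated `powerPreference` and `forceFallbackAdapter` knobs. A sketch of the intended usage, assuming onnxruntime-web from this branch and "@webgpu/types" for the `navigator.gpu` typings; the model URL is illustrative.

```typescript
import * as ort from 'onnxruntime-web';

async function createWebGpuSession(modelUrl: string): Promise<ort.InferenceSession> {
  // Set a GPUDevice before the first WebGPU session is created; the backend will use it.
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: 'high-performance' });
  if (!adapter) {
    throw new Error('WebGPU is not available');
  }
  ort.env.webgpu.device = await adapter.requestDevice();

  const session = await ort.InferenceSession.create(modelUrl, { executionProviders: ['webgpu'] });

  // The getter returns a Promise resolving to the GPUDevice actually used by the backend.
  const device = await ort.env.webgpu.device;
  console.log('WebGPU device in use:', device.label);

  return session;
}
```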
* diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 3ed56b3c2e812..d75e6a477258d 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -26,4 +26,3 @@ export * from './tensor-factory.js'; export * from './trace.js'; export * from './onnx-model.js'; export * from './onnx-value.js'; -export * from './training-session.js'; diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index af8a8c76c8fe4..e62c6579e8333 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -4,6 +4,7 @@ import { InferenceSession as InferenceSessionImpl } from './inference-session-impl.js'; import { OnnxModelOptions } from './onnx-model.js'; import { OnnxValue, OnnxValueDataLocation } from './onnx-value.js'; +import { TryGetGlobalType } from './type-helper.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -282,7 +283,7 @@ export declare namespace InferenceSession { extends WebNNExecutionProviderName, Omit, Required> { - context: unknown /* MLContext */; + context: TryGetGlobalType<'MLContext'>; } /** @@ -291,8 +292,8 @@ export declare namespace InferenceSession { * @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice */ export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName { - context: unknown /* MLContext */; - gpuDevice: unknown /* GPUDevice */; + context: TryGetGlobalType<'MLContext'>; + gpuDevice: TryGetGlobalType<'GPUDevice'>; } /** @@ -320,6 +321,7 @@ export declare namespace InferenceSession { * COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004 * COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008 * COREML_FLAG_CREATE_MLPROGRAM = 0x010 + * COREML_FLAG_USE_CPU_AND_GPU = 0x020 * ``` * * See include/onnxruntime/core/providers/coreml/coreml_provider_factory.h for more details. @@ -333,6 +335,7 @@ export declare namespace InferenceSession { * This setting is available only in ONNXRuntime (react-native). */ useCPUOnly?: boolean; + useCPUAndGPU?: boolean; /** * Specify whether to enable CoreML EP on subgraph. * diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index c0e1582c17de5..8feb8d7205fa1 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -179,7 +179,9 @@ export class Tensor implements TensorInterface { type !== 'uint64' && type !== 'int8' && type !== 'uint8' && - type !== 'bool' + type !== 'bool' && + type !== 'uint4' && + type !== 'int4' ) { throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`); } diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 17e2f4d37c91f..05553bd96662b 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -4,6 +4,7 @@ import { TensorFactory } from './tensor-factory.js'; import { Tensor as TensorImpl } from './tensor-impl.js'; import { TypedTensorUtils } from './tensor-utils.js'; +import { TryGetGlobalType } from './type-helper.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -131,24 +132,19 @@ export declare namespace Tensor { */ export type TextureDataTypes = 'float32'; + type GpuBufferTypeFallback = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; /** * type alias for WebGPU buffer - * - * The reason why we don't use type "GPUBuffer" defined in webgpu.d.ts from @webgpu/types is because "@webgpu/types" - * requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its version need to be chosen - * carefully according to the TypeScript version being used. 
This means so far there is not a way to keep every - * TypeScript version happy. It turns out that we will easily broke users on some TypeScript version. - * - * for more info see https://github.com/gpuweb/types/issues/127 */ - export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; + export type GpuBufferType = TryGetGlobalType<'GPUBuffer', GpuBufferTypeFallback>; + type MLTensorTypeFallback = { destroy(): void }; /** * type alias for WebNN MLTensor * * The specification for WebNN's MLTensor is currently in flux. */ - export type MLTensorType = unknown; + export type MLTensorType = TryGetGlobalType<'MLTensor', MLTensorTypeFallback>; /** * supported data types for constructing a tensor from a WebGPU buffer @@ -167,7 +163,9 @@ export declare namespace Tensor { | 'uint32' | 'int64' | 'uint64' - | 'bool'; + | 'bool' + | 'uint4' + | 'int4'; /** * represent where the tensor data is stored diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts deleted file mode 100644 index 21dbe5fe51bb9..0000000000000 --- a/js/common/lib/training-session-impl.ts +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { resolveBackendAndExecutionProviders } from './backend-impl.js'; -import { SessionHandler, TrainingSessionHandler } from './backend.js'; -import { InferenceSession as InferenceSession } from './inference-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { Tensor } from './tensor.js'; -import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js'; - -type SessionOptions = InferenceSession.SessionOptions; -type FeedsType = InferenceSession.FeedsType; -type FetchesType = InferenceSession.FetchesType; -type ReturnType = InferenceSession.ReturnType; -type RunOptions = InferenceSession.RunOptions; - -const noBackendErrMsg: string = - 'Training backend could not be resolved. 
' + "Make sure you're using the correct configuration & WebAssembly files."; - -export class TrainingSession implements TrainingSessionInterface { - private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) { - this.handler = handler; - this.hasOptimizerModel = hasOptimizerModel; - this.hasEvalModel = hasEvalModel; - } - private handler: TrainingSessionHandler; - private hasOptimizerModel: boolean; - private hasEvalModel: boolean; - - get trainingInputNames(): readonly string[] { - return this.handler.inputNames; - } - get trainingOutputNames(): readonly string[] { - return this.handler.outputNames; - } - - get evalInputNames(): readonly string[] { - if (this.hasEvalModel) { - return this.handler.evalInputNames; - } else { - throw new Error('This training session has no evalModel loaded.'); - } - } - get evalOutputNames(): readonly string[] { - if (this.hasEvalModel) { - return this.handler.evalOutputNames; - } else { - throw new Error('This training session has no evalModel loaded.'); - } - } - - static async create( - trainingOptions: TrainingSessionCreateOptions, - sessionOptions?: SessionOptions, - ): Promise { - const evalModel: string | Uint8Array = trainingOptions.evalModel || ''; - const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || ''; - const options: SessionOptions = sessionOptions || {}; - - // resolve backend, update session options with validated EPs, and create session handler - const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); - if (backend.createTrainingSessionHandler) { - const handler = await backend.createTrainingSessionHandler( - trainingOptions.checkpointState, - trainingOptions.trainModel, - evalModel, - optimizerModel, - optionsWithValidatedEPs, - ); - return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); - } else { - throw new Error(noBackendErrMsg); - } - } - - /** - * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from - * the given parameters to SessionHandler.FetchesType and RunOptions. - * - * @param inputNames the feeds object is checked that they contain all input names in the provided list of input - * names. - * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output - * names. - * @param feeds the required input - * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object - * @param arg2 optional RunOptions object. 
- * @returns - */ - typeNarrowingForRunStep( - inputNames: readonly string[], - outputNames: readonly string[], - feeds: FeedsType, - arg1?: FetchesType | RunOptions, - arg2?: RunOptions, - ): [SessionHandler.FetchesType, RunOptions] { - const fetches: { [name: string]: OnnxValue | null } = {}; - let options: RunOptions = {}; - // check inputs - if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { - throw new TypeError( - "'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.", - ); - } - - let isFetchesEmpty = true; - // determine which override is being used - if (typeof arg1 === 'object') { - if (arg1 === null) { - throw new TypeError('Unexpected argument[1]: cannot be null.'); - } - if (arg1 instanceof Tensor) { - throw new TypeError("'fetches' cannot be a Tensor"); - } - - if (Array.isArray(arg1)) { - if (arg1.length === 0) { - throw new TypeError("'fetches' cannot be an empty array."); - } - isFetchesEmpty = false; - // output names - for (const name of arg1) { - if (typeof name !== 'string') { - throw new TypeError("'fetches' must be a string array or an object."); - } - if (outputNames.indexOf(name) === -1) { - throw new RangeError(`'fetches' contains invalid output name: ${name}.`); - } - fetches[name] = null; - } - - if (typeof arg2 === 'object' && arg2 !== null) { - options = arg2; - } else if (typeof arg2 !== 'undefined') { - throw new TypeError("'options' must be an object."); - } - } else { - // decide whether arg1 is fetches or options - // if any output name is present and its value is valid OnnxValue, we consider it fetches - let isFetches = false; - const arg1Keys = Object.getOwnPropertyNames(arg1); - for (const name of outputNames) { - if (arg1Keys.indexOf(name) !== -1) { - const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; - if (v === null || v instanceof Tensor) { - isFetches = true; - isFetchesEmpty = false; - fetches[name] = v; - } - } - } - - if (isFetches) { - if (typeof arg2 === 'object' && arg2 !== null) { - options = arg2; - } else if (typeof arg2 !== 'undefined') { - throw new TypeError("'options' must be an object."); - } - } else { - options = arg1 as RunOptions; - } - } - } else if (typeof arg1 !== 'undefined') { - throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'."); - } - - // check if all inputs are in feed - for (const name of inputNames) { - if (typeof feeds[name] === 'undefined') { - throw new Error(`input '${name}' is missing in 'feeds'.`); - } - } - - // if no fetches is specified, we use the full output names list - if (isFetchesEmpty) { - for (const name of outputNames) { - fetches[name] = null; - } - } - - return [fetches, options]; - } - - /** - * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler - * and changes it into a map of Tensors. 
- * - * @param results - * @returns - */ - convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { - const returnValue: { [name: string]: OnnxValue } = {}; - for (const key in results) { - if (Object.hasOwnProperty.call(results, key)) { - const result = results[key]; - if (result instanceof Tensor) { - returnValue[key] = result; - } else { - returnValue[key] = new Tensor(result.type, result.data, result.dims); - } - } - } - return returnValue; - } - - async lazyResetGrad(): Promise { - await this.handler.lazyResetGrad(); - } - - runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; - runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; - async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { - const [fetches, options] = this.typeNarrowingForRunStep( - this.trainingInputNames, - this.trainingOutputNames, - feeds, - arg1, - arg2, - ); - const results = await this.handler.runTrainStep(feeds, fetches, options); - return this.convertHandlerReturnTypeToMapOfTensors(results); - } - - async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise { - if (this.hasOptimizerModel) { - await this.handler.runOptimizerStep(options || {}); - } else { - throw new Error('This TrainingSession has no OptimizerModel loaded.'); - } - } - - runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise; - runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise; - async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { - if (this.hasEvalModel) { - const [fetches, options] = this.typeNarrowingForRunStep( - this.evalInputNames, - this.evalOutputNames, - feeds, - arg1, - arg2, - ); - const results = await this.handler.runEvalStep(feeds, fetches, options); - return this.convertHandlerReturnTypeToMapOfTensors(results); - } else { - throw new Error('This TrainingSession has no EvalModel loaded.'); - } - } - - async getParametersSize(trainableOnly = true): Promise { - return this.handler.getParametersSize(trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise { - const paramsSize = await this.getParametersSize(trainableOnly); - // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number - // of parameters - if (array.length !== 4 * paramsSize) { - throw new Error( - 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + - 'the model. Please use getParametersSize method to check.', - ); - } - return this.handler.loadParametersBuffer(array, trainableOnly); - } - - async getContiguousParameters(trainableOnly = true): Promise { - return this.handler.getContiguousParameters(trainableOnly); - } - - async release(): Promise { - return this.handler.dispose(); - } -} diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts deleted file mode 100644 index 45dcafc46deb5..0000000000000 --- a/js/common/lib/training-session.ts +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession } from './inference-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js'; - -/* eslint-disable @typescript-eslint/no-redeclare */ - -export declare namespace TrainingSession { - /** - * Either URI file path (string) or Uint8Array containing model or checkpoint information. - */ - type UriOrBuffer = string | Uint8Array; -} - -/** - * Represent a runtime instance of an ONNX training session, - * which contains a model that can be trained, and, optionally, - * an eval and optimizer model. - */ -export interface TrainingSession { - // #region run() - - /** - * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of - * runOptimizerStep. - */ - lazyResetGrad(): Promise; - - /** - * Run TrainStep asynchronously with the given feeds and options. - * - * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for - detail. - * @param options - Optional. A set of options that controls the behavior of model training. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. - */ - runTrainStep( - feeds: InferenceSession.FeedsType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Run a single train step with the given inputs and options. - * - * @param feeds - Representation of the model input. - * @param fetches - Representation of the model output. - * detail. - * @param options - Optional. A set of options that controls the behavior of model training. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runTrainStep( - feeds: InferenceSession.FeedsType, - fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model. - * - * @param options - Optional. A set of options that controls the behavior of model optimizing. - */ - runOptimizerStep(options?: InferenceSession.RunOptions): Promise; - - /** - * Run a single eval step with the given inputs and options using the eval model. - * - * @param feeds - Representation of the model input. - * @param options - Optional. A set of options that controls the behavior of model eval step. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runEvalStep( - feeds: InferenceSession.FeedsType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Run a single eval step with the given inputs and options using the eval model. - * - * @param feeds - Representation of the model input. - * @param fetches - Representation of the model output. - * detail. - * @param options - Optional. A set of options that controls the behavior of model eval step. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runEvalStep( - feeds: InferenceSession.FeedsType, - fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions, - ): Promise; - - // #endregion - - // #region copy parameters - - /** - * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of - * the parameters) elements of all the parameters in the training state. 
- * - * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true. - */ - getParametersSize(trainableOnly: boolean): Promise; - - /** - * Copies parameter values from the given buffer to the training state. Currently, only supporting models with - * parameters of type Float32. - * - * @param buffer - A Uint8Array representation of Float32 parameters. - * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. - */ - loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; - - /** - * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. - * Currently, only supporting models with parameters of type Float32. - * - * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters - * for which requires_grad is set to true. Default value is true. - * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters. - */ - getContiguousParameters(trainableOnly: boolean): Promise; - // #endregion - - // #region release() - - /** - * Release the inference session and the underlying resources. - */ - release(): Promise; - // #endregion - - // #region metadata - - /** - * Get input names of the loaded training model. - */ - readonly trainingInputNames: readonly string[]; - - /** - * Get output names of the loaded training model. - */ - readonly trainingOutputNames: readonly string[]; - - /** - * Get input names of the loaded eval model. Is an empty array if no eval model is loaded. - */ - readonly evalInputNames: readonly string[]; - - /** - * Get output names of the loaded eval model. Is an empty array if no eval model is loaded. - */ - readonly evalOutputNames: readonly string[]; - - // #endregion -} - -/** - * Represents the optional parameters that can be passed into the TrainingSessionFactory. - */ -export interface TrainingSessionCreateOptions { - /** - * URI or buffer for a .ckpt file that contains the checkpoint for the training model. - */ - checkpointState: TrainingSession.UriOrBuffer; - /** - * URI or buffer for the .onnx training file. - */ - trainModel: TrainingSession.UriOrBuffer; - /** - * Optional. URI or buffer for the .onnx optimizer model file. - */ - optimizerModel?: TrainingSession.UriOrBuffer; - /** - * Optional. URI or buffer for the .onnx eval model file. - */ - evalModel?: TrainingSession.UriOrBuffer; -} - -/** - * Defines method overload possibilities for creating a TrainingSession. - */ -export interface TrainingSessionFactory { - // #region create() - - /** - * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions - * - * @param trainingOptions specify models and checkpoints to load into the Training Session - * @param sessionOptions specify configuration for training session behavior - * - * @returns Promise that resolves to a TrainingSession object - */ - create( - trainingOptions: TrainingSessionCreateOptions, - sessionOptions?: InferenceSession.SessionOptions, - ): Promise; - - // #endregion -} - -// eslint-disable-next-line @typescript-eslint/naming-convention -export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl; diff --git a/js/common/lib/type-helper.ts b/js/common/lib/type-helper.ts new file mode 100644 index 0000000000000..845ba3018d443 --- /dev/null +++ b/js/common/lib/type-helper.ts @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +/** + * A helper type to get certain types if they are declared in global scope. + * + * For example, if you installed "@webgpu/types" as a dev dependency, then `TryGetTypeIfDeclared<'GPUDevice'>` will + * be type `GPUDevice`, otherwise it will be type `unknown`. + * + * + * We don't want to introduce "@webgpu/types" as a dependency of this package because: + * + * (1) For JavaScript users, it's not needed. For TypeScript users, they can install it as dev dependency themselves. + * + * (2) because "@webgpu/types" requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its + * version need to be chosen carefully according to the TypeScript version being used. This means so far there is not a + * way to keep every TypeScript version happy. It turns out that we will easily broke users on some TypeScript version. + * + * for more info see https://github.com/gpuweb/types/issues/127 + * + * Update (2024-08-07): The reason (2) may be no longer valid. Most people should be using TypeScript >= 5.1 by now. + * However, we are still not sure whether introducing "@webgpu/types" as direct dependency is a good idea. We find this + * type helper is useful for TypeScript users. + * + * @ignore + */ +export type TryGetGlobalType = typeof globalThis extends { + [k in Name]: { prototype: infer T }; +} + ? T + : Fallback; diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 450ae2d06e638..475dfe0d4888b 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.20.0'; +export const version = '1.21.0'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 865fa860e98ad..4d92e0f73aa69 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.20.0", + "version": "1.21.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.20.0", + "version": "1.21.0", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/package.json b/js/common/package.json index 9c941f6486ea9..2e2161c74158c 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.20.0", + "version": "1.21.0", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/common/typedoc.json b/js/common/typedoc.json index 088c7ba4053e6..f9c7e7b19db41 100644 --- a/js/common/typedoc.json +++ b/js/common/typedoc.json @@ -1,6 +1,7 @@ { "entryPoints": ["lib/index.ts"], "excludeInternal": true, + "intentionallyNotExported": ["TryGetGlobalType"], "name": "ONNX Runtime JavaScript API", "readme": "none", "cleanOutputDir": true diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index 1ce6d66881c3e..d79a82c572dc2 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.11) project (onnxruntime-node) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) add_compile_definitions(NAPI_VERSION=${napi_build_version}) add_compile_definitions(ORT_API_MANUAL_INIT) @@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api) # optional providers option(USE_DML "Build with DirectML support" OFF) 
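Note: the `TryGetGlobalType` helper added above is small enough to restate in isolation together with the two aliases from `tensor.ts` that use it. This is a compile-time-only sketch: with "@webgpu/types" installed, `GpuBufferType` resolves to the real `GPUBuffer`; without it, the structural fallback is used, so downstream code type-checks either way.

```typescript
export type TryGetGlobalType<Name extends string, Fallback = unknown> = typeof globalThis extends {
  [k in Name]: { prototype: infer T };
}
  ? T
  : Fallback;

// Mirrors the aliases in js/common/lib/tensor.ts.
type GpuBufferTypeFallback = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
export type GpuBufferType = TryGetGlobalType<'GPUBuffer', GpuBufferTypeFallback>;

type MLTensorTypeFallback = { destroy(): void };
export type MLTensorType = TryGetGlobalType<'MLTensor', MLTensorTypeFallback>;
```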
+option(USE_WEBGPU "Build with WebGPU support" OFF) option(USE_CUDA "Build with CUDA support" OFF) option(USE_TENSORRT "Build with TensorRT support" OFF) option(USE_COREML "Build with CoreML support" OFF) @@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF) if(USE_DML) add_compile_definitions(USE_DML=1) endif() +if(USE_WEBGPU) + add_compile_definitions(USE_WEBGPU=1) +endif() if(USE_CUDA) add_compile_definitions(USE_CUDA=1) endif() diff --git a/js/node/README.md b/js/node/README.md index 3f4da7ddd4135..abb91bf05ddf1 100644 --- a/js/node/README.md +++ b/js/node/README.md @@ -14,7 +14,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Requirements -ONNXRuntime works on Node.js v16.x+ (recommend v18.x+) or Electron v15.x+ (recommend v28.x+). +ONNXRuntime works on Node.js v16.x+ (recommend v20.x+) or Electron v15.x+ (recommend v28.x+). The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries. diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index 46f8b83b0c5c2..004a3c890a7e4 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -3,12 +3,14 @@ import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common'; -import { Binding, binding } from './binding'; +import { Binding, binding, initOrt } from './binding'; class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) { + initOrt(); + this.#inferenceSession = new binding.InferenceSession(); if (typeof pathOrBuffer === 'string') { this.#inferenceSession.loadModel(pathOrBuffer, options); @@ -27,10 +29,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { readonly outputNames: string[]; startProfiling(): void { - // TODO: implement profiling + // startProfiling is a no-op. + // + // if sessionOptions.enableProfiling is true, profiling will be enabled when the model is loaded. } endProfiling(): void { - // TODO: implement profiling + this.#inferenceSession.endProfiling(); } async run( diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts index d6d592a1665b3..56203f5a5ca02 100644 --- a/js/node/lib/binding.ts +++ b/js/node/lib/binding.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
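Note: with the Node.js binding change above, `startProfiling()` stays a no-op and profiling is instead enabled through session options, while `endProfiling()` now reaches the native session. A sketch of driving this from application code; the model path, input name, and tensor shape are illustrative.

```typescript
import * as ort from 'onnxruntime-node';

async function profiledRun(): Promise<void> {
  // Forwarded to the native runtime when the first session is created (see initOrt below).
  ort.env.logLevel = 'verbose';

  const session = await ort.InferenceSession.create('model.onnx', {
    enableProfiling: true,
    // Required whenever enableProfiling is true (enforced in session_options_helper.cc later in this diff).
    profileFilePrefix: 'ort-profile',
  });

  // The feed shape must match the actual model; this one is a placeholder.
  const input = new ort.Tensor('float32', new Float32Array(3 * 4 * 5), [3, 4, 5]);
  await session.run({ [session.inputNames[0]]: input });

  // Writes out the profile JSON file derived from profileFilePrefix.
  session.endProfiling();
  await session.release();
}
```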
-import { InferenceSession, OnnxValue } from 'onnxruntime-common'; +import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = { @@ -28,6 +28,8 @@ export declare namespace Binding { run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType; + endProfiling(): void; + dispose(): void; } @@ -48,4 +50,35 @@ export const binding = // eslint-disable-next-line @typescript-eslint/naming-convention InferenceSession: Binding.InferenceSessionConstructor; listSupportedBackends: () => Binding.SupportedBackend[]; + initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void; }; + +let ortInitialized = false; +export const initOrt = (): void => { + if (!ortInitialized) { + ortInitialized = true; + let logLevel = 2; + if (env.logLevel) { + switch (env.logLevel) { + case 'verbose': + logLevel = 0; + break; + case 'info': + logLevel = 1; + break; + case 'warning': + logLevel = 2; + break; + case 'error': + logLevel = 3; + break; + case 'fatal': + logLevel = 4; + break; + default: + throw new Error(`Unsupported log level: ${env.logLevel}`); + } + } + binding.initOrtOnce(logLevel, Tensor); + } +}; diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 450ae2d06e638..475dfe0d4888b 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.20.0'; +export const version = '1.21.0'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index a0fc445c16dda..6d3c96e579a47 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.20.0", + "version": "1.21.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.20.0", + "version": "1.21.0", "hasInstallScript": true, "license": "MIT", "os": [ @@ -29,7 +29,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.20.0", + "version": "1.21.0", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" @@ -276,12 +276,12 @@ "dev": true }, "node_modules/axios": { - "version": "1.6.1", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", - "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", + "version": "1.7.9", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz", + "integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==", "dev": true, "dependencies": { - "follow-redirects": "^1.15.0", + "follow-redirects": "^1.15.6", "form-data": "^4.0.0", "proxy-from-env": "^1.1.0" } @@ -455,9 +455,9 @@ "dev": true }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", @@ -1581,12 +1581,12 @@ "dev": true }, "axios": { - "version": "1.6.1", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", - "integrity": 
"sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", + "version": "1.7.9", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz", + "integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==", "dev": true, "requires": { - "follow-redirects": "^1.15.0", + "follow-redirects": "^1.15.6", "form-data": "^4.0.0", "proxy-from-env": "^1.1.0" } @@ -1725,9 +1725,9 @@ "dev": true }, "cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "requires": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", diff --git a/js/node/package.json b/js/node/package.json index 4964d0fc3fd4d..1608f87a3d299 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.20.0", + "version": "1.21.0", "dependencies": { "onnxruntime-common": "file:../common", "tar": "^7.0.1" diff --git a/js/node/script/build.ts b/js/node/script/build.ts index 133d1a0d981a0..dcdcb93377b4c 100644 --- a/js/node/script/build.ts +++ b/js/node/script/build.ts @@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator']; const REBUILD = !!buildArgs.rebuild; // --use_dml const USE_DML = !!buildArgs.use_dml; +// --use_webgpu +const USE_WEBGPU = !!buildArgs.use_webgpu; // --use_cuda const USE_CUDA = !!buildArgs.use_cuda; // --use_tensorrt @@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') { if (USE_DML) { args.push('--CDUSE_DML=ON'); } +if (USE_WEBGPU) { + args.push('--CDUSE_WEBGPU=ON'); +} if (USE_CUDA) { args.push('--CDUSE_CUDA=ON'); } diff --git a/js/node/script/install.js b/js/node/script/install.js index b15bc03840599..fef93f9169a2c 100644 --- a/js/node/script/install.js +++ b/js/node/script/install.js @@ -21,6 +21,7 @@ const os = require('os'); const fs = require('fs'); const path = require('path'); const tar = require('tar'); +const { execFileSync } = require('child_process'); const { Readable } = require('stream'); // commandline flag: @@ -58,10 +59,23 @@ if (NO_INSTALL || !shouldInstall) { // Step.2: Download the required binaries const artifactUrl = { - 11: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-${ - ORT_VERSION - }.tgz`, - 12: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-cuda12-${ + get 11() { + // TODO: support ORT Cuda v11 binaries + throw new Error(`CUDA 11 binaries are not supported by this script yet. + +To use ONNX Runtime Node.js binding with CUDA v11 support, please follow the manual steps: + +1. Use "--onnxruntime-node-install-cuda=skip" to skip the auto installation. +2. Navigate to https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/onnxruntime-cuda-11 +3. Download the binaries for your platform and architecture +4. 
Extract the following binaries to "node_modules/onnxruntime-node/bin/napi-v3/linux/x64: + - libonnxruntime_providers_tensorrt.so + - libonnxruntime_providers_shared.so + - libonnxruntime.so.${ORT_VERSION} + - libonnxruntime_providers_cuda.so +`); + }, + 12: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-${ ORT_VERSION }.tgz`, }[INSTALL_CUDA_FLAG || tryGetCudaVersion()]; @@ -108,9 +122,27 @@ Use "--onnxruntime-node-install-cuda=skip" to skip the installation. You will st function tryGetCudaVersion() { // Should only return 11 or 12. - // TODO: try to get the CUDA version from the system ( `nvcc --version` ) + // try to get the CUDA version from the system ( `nvcc --version` ) + let ver = 12; + try { + const nvccVersion = execFileSync('nvcc', ['--version'], { encoding: 'utf8' }); + const match = nvccVersion.match(/release (\d+)/); + if (match) { + ver = parseInt(match[1]); + if (ver !== 11 && ver !== 12) { + throw new Error(`Unsupported CUDA version: ${ver}`); + } + } + } catch (e) { + if (e?.code === 'ENOENT') { + console.warn('`nvcc` not found. Assuming CUDA 12.'); + } else { + console.warn('Failed to detect CUDA version from `nvcc --version`:', e.message); + } + } - return 11; + // assume CUDA 12 if failed to detect + return ver; } function parseInstallCudaFlag() { diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 057066507621b..23d859351f426 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -11,7 +11,12 @@ #include "tensor_helper.h" #include -Napi::FunctionReference InferenceSessionWrap::constructor; +Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor; +Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor; + +Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() { + return InferenceSessionWrap::ortTensorConstructor; +} Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { #if defined(USE_DML) && defined(_WIN32) @@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { Ort::Global::api_ == nullptr, env, "Failed to initialize ONNX Runtime API. 
It could happen when this nodejs binding was built with a higher version " "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library)."); - auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"}; - env.SetInstanceData(ortEnv); + // initialize binding Napi::HandleScope scope(env); Napi::Function func = DefineClass( env, "InferenceSession", - {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run), + {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), + InstanceMethod("run", &InferenceSessionWrap::Run), InstanceMethod("dispose", &InferenceSessionWrap::Dispose), + InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling), InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr), InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)}); - constructor = Napi::Persistent(func); - constructor.SuppressDestruct(); + wrappedSessionConstructor = Napi::Persistent(func); + wrappedSessionConstructor.SuppressDestruct(); exports.Set("InferenceSession", func); Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends); exports.Set("listSupportedBackends", listSupportedBackends); + Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce); + exports.Set("initOrtOnce", initOrtOnce); + return exports; } +Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + Napi::HandleScope scope(env); + + int log_level = info[0].As().Int32Value(); + + Ort::Env* ortEnv = env.GetInstanceData(); + if (ortEnv == nullptr) { + ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"}; + env.SetInstanceData(ortEnv); + } + + Napi::Function tensorConstructor = info[1].As(); + ortTensorConstructor = Napi::Persistent(tensorConstructor); + ortTensorConstructor.SuppressDestruct(); + + return env.Undefined(); +} + InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {} @@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) { ? 
typeInfo.GetTensorTypeAndShapeInfo().GetElementType() : ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); } + + // cache preferred output locations + ParsePreferredOutputLocations(info[argsLength - 1].As(), outputNames_, preferredOutputLocations_); + if (preferredOutputLocations_.size() > 0) { + ioBinding_ = std::make_unique(*session_); + } } catch (Napi::Error const& e) { throw e; } catch (std::exception const& e) { @@ -167,7 +201,8 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { std::vector reuseOutput; size_t inputIndex = 0; size_t outputIndex = 0; - OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release(); + Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault}; try { for (auto& name : inputNames_) { @@ -175,7 +210,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { inputIndex++; inputNames_cstr.push_back(name.c_str()); auto value = feed.Get(name); - inputValues.push_back(NapiValueToOrtValue(env, value, memory_info)); + inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo)); } } for (auto& name : outputNames_) { @@ -184,7 +219,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { outputNames_cstr.push_back(name.c_str()); auto value = fetch.Get(name); reuseOutput.push_back(!value.IsNull()); - outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info)); + outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo)); } } @@ -193,19 +228,47 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { runOptions = Ort::RunOptions{}; ParseRunOptions(info[2].As(), runOptions); } + if (preferredOutputLocations_.size() == 0) { + session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, + inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0], + inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0], + outputIndex == 0 ? nullptr : &outputValues[0], outputIndex); - session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, - inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0], - inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0], - outputIndex == 0 ? 
nullptr : &outputValues[0], outputIndex); + Napi::Object result = Napi::Object::New(env); - Napi::Object result = Napi::Object::New(env); + for (size_t i = 0; i < outputIndex; i++) { + result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i]))); + } + return scope.Escape(result); + } else { + // IO binding + ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env, + "Preferred output locations must have the same size as output names."); - for (size_t i = 0; i < outputIndex; i++) { - result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i])); - } + for (size_t i = 0; i < inputIndex; i++) { + ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]); + } + for (size_t i = 0; i < outputIndex; i++) { + // TODO: support preallocated output tensor (outputValues[i]) + + if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) { + ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo); + } else { + ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo); + } + } + + session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_); + + auto outputs = ioBinding_->GetOutputValues(); + ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch."); - return scope.Escape(result); + Napi::Object result = Napi::Object::New(env); + for (size_t i = 0; i < outputIndex; i++) { + result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i]))); + } + return scope.Escape(result); + } } catch (Napi::Error const& e) { throw e; } catch (std::exception const& e) { @@ -218,6 +281,8 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) { ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + this->ioBinding_.reset(nullptr); + this->defaultRunOptions_.reset(nullptr); this->session_.reset(nullptr); @@ -225,6 +290,20 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) { return env.Undefined(); } +Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + + Napi::EscapableHandleScope scope(env); + + Ort::AllocatorWithDefaultOptions allocator; + + auto filename = session_->EndProfilingAllocated(allocator); + Napi::String filenameValue = Napi::String::From(env, filename.get()); + return scope.Escape(filenameValue); +} + Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); Napi::EscapableHandleScope scope(env); @@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo #ifdef USE_DML result.Set(result.Length(), createObject("dml", true)); #endif +#ifdef USE_WEBGPU + result.Set(result.Length(), createObject("webgpu", true)); +#endif #ifdef USE_CUDA result.Set(result.Length(), createObject("cuda", false)); #endif diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index effdd83e3aa02..0b3dd1178c807 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -12,9 +12,22 @@ class InferenceSessionWrap : public Napi::ObjectWrap { public: static Napi::Object Init(Napi::Env env, Napi::Object exports); + static Napi::FunctionReference& 
GetTensorConstructor(); + InferenceSessionWrap(const Napi::CallbackInfo& info); private: + /** + * [sync] initialize ONNX Runtime once. + * + * This function must be called before any other functions. + * + * @param arg0 a number specifying the log level. + * + * @returns undefined + */ + static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info); + /** * [sync] list supported backend list * @returns array with objects { "name": "cpu", requirementsInstalled: true } @@ -63,10 +76,19 @@ class InferenceSessionWrap : public Napi::ObjectWrap { */ Napi::Value Dispose(const Napi::CallbackInfo& info); + /** + * [sync] end the profiling. + * @param nothing + * @returns nothing + * @throw nothing + */ + Napi::Value EndProfiling(const Napi::CallbackInfo& info); + // private members // persistent constructor - static Napi::FunctionReference constructor; + static Napi::FunctionReference wrappedSessionConstructor; + static Napi::FunctionReference ortTensorConstructor; // session objects bool initialized_; @@ -81,4 +103,8 @@ class InferenceSessionWrap : public Napi::ObjectWrap { std::vector outputNames_; std::vector outputTypes_; std::vector outputTensorElementDataTypes_; + + // preferred output locations + std::vector preferredOutputLocations_; + std::unique_ptr ioBinding_; }; diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 0ed1ba08e6bf7..8c1d7ca06b8c3 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -6,15 +6,20 @@ #include #include +#include #include "common.h" #include "session_options_helper.h" +#include "tensor_helper.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" #endif #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" #endif +#ifdef USE_WEBGPU +#include "core/providers/webgpu/webgpu_provider_factory.h" +#endif #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_options.h" #endif @@ -36,7 +41,12 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess Napi::Value epValue = epList[i]; std::string name; int deviceId = 0; +#ifdef USE_COREML int coreMlFlags = 0; +#endif +#ifdef USE_WEBGPU + std::unordered_map webgpu_options; +#endif if (epValue.IsString()) { name = epValue.As().Utf8Value(); } else if (!epValue.IsObject() || epValue.IsNull() || !epValue.As().Has("name") || @@ -49,9 +59,23 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess if (obj.Has("deviceId")) { deviceId = obj.Get("deviceId").As(); } +#ifdef USE_COREML if (obj.Has("coreMlFlags")) { coreMlFlags = obj.Get("coreMlFlags").As(); } +#endif +#ifdef USE_WEBGPU + for (const auto& nameIter : obj.GetPropertyNames()) { + Napi::Value nameVar = nameIter.second; + std::string name = nameVar.As().Utf8Value(); + if (name != "name") { + Napi::Value valueVar = obj.Get(nameVar); + ORT_NAPI_THROW_TYPEERROR_IF(!valueVar.IsString(), epList.Env(), "Invalid argument: sessionOptions.executionProviders must be a string or an object with property 'name'."); + std::string value = valueVar.As().Utf8Value(); + webgpu_options[name] = value; + } + } +#endif } // CPU execution provider @@ -77,6 +101,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } else if (name == "dml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId)); #endif +#ifdef USE_WEBGPU + } else if (name == "webgpu") { + sessionOptions.AppendExecutionProvider("WebGPU", webgpu_options); +#endif #ifdef 
USE_COREML } else if (name == "coreml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags)); @@ -95,6 +123,22 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } } +void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) { + for (const auto& kvp : obj) { + auto key = kvp.first.As<Napi::String>().Utf8Value(); + Napi::Value value = kvp.second; + if (value.IsObject()) { + IterateExtraOptions(prefix + key + ".", value.As<Napi::Object>(), sessionOptions); + } else { + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString(), obj.Env(), + "Invalid argument: sessionOptions.extra value must be a string in Node.js binding."); + std::string entry = prefix + key; + auto val = value.As<Napi::String>().Utf8Value(); + sessionOptions.AddConfigEntry(entry.c_str(), val.c_str()); + } + } +} + void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) { // Execution provider if (options.Has("executionProviders")) { @@ -162,6 +206,28 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio } } + // optimizedModelFilePath + if (options.Has("optimizedModelFilePath")) { + auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath"); + ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(), + "Invalid argument: sessionOptions.optimizedModelFilePath must be a string."); +#ifdef _WIN32 + auto str = optimizedModelFilePathValue.As<Napi::String>().Utf16Value(); + std::filesystem::path optimizedModelFilePath{std::wstring{str.begin(), str.end()}}; +#else + std::filesystem::path optimizedModelFilePath{optimizedModelFilePathValue.As<Napi::String>().Utf8Value()}; +#endif + sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str()); + } + + // extra + if (options.Has("extra")) { + auto extraValue = options.Get("extra"); + ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(), + "Invalid argument: sessionOptions.extra must be an object."); + IterateExtraOptions("", extraValue.As<Napi::Object>(), sessionOptions); + } + // execution mode if (options.Has("executionMode")) { auto executionModeValue = options.Get("executionMode"); @@ -195,4 +261,118 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio sessionOptions.SetLogSeverityLevel(static_cast<int>(logLevelNumber)); } + + // Profiling + if (options.Has("enableProfiling")) { + auto enableProfilingValue = options.Get("enableProfiling"); + ORT_NAPI_THROW_TYPEERROR_IF(!enableProfilingValue.IsBoolean(), options.Env(), + "Invalid argument: sessionOptions.enableProfiling must be a boolean value."); + + if (enableProfilingValue.As<Napi::Boolean>().Value()) { + ORT_NAPI_THROW_TYPEERROR_IF(!options.Has("profileFilePrefix"), options.Env(), + "Invalid argument: sessionOptions.profileFilePrefix is required" + " when sessionOptions.enableProfiling is set to true."); + auto profileFilePrefixValue = options.Get("profileFilePrefix"); + ORT_NAPI_THROW_TYPEERROR_IF(!profileFilePrefixValue.IsString(), options.Env(), + "Invalid argument: sessionOptions.profileFilePrefix must be a string"
+ " when sessionOptions.enableProfiling is set to true."); +#ifdef _WIN32 + auto str = profileFilePrefixValue.As().Utf16Value(); + std::basic_string profileFilePrefix = std::wstring{str.begin(), str.end()}; +#else + std::basic_string profileFilePrefix = profileFilePrefixValue.As().Utf8Value(); +#endif + sessionOptions.EnableProfiling(profileFilePrefix.c_str()); + } else { + sessionOptions.DisableProfiling(); + } + } + + // external data + if (options.Has("externalData")) { + auto externalDataValue = options.Get("externalData"); + ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(), + "Invalid argument: sessionOptions.externalData must be an array."); + auto externalData = externalDataValue.As(); + std::vector> paths; + std::vector buffs; + std::vector sizes; + + for (const auto& kvp : externalData) { + Napi::Value value = kvp.second; + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(), + "Invalid argument: sessionOptions.externalData value must be an object in Node.js binding."); + Napi::Object obj = value.As(); + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(), + "Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding."); +#ifdef _WIN32 + auto path = obj.Get("path").As().Utf16Value(); + paths.push_back(std::wstring{path.begin(), path.end()}); +#else + auto path = obj.Get("path").As().Utf8Value(); + paths.push_back(path); +#endif + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") || + !obj.Get("data").IsBuffer() || + !(obj.Get("data").IsTypedArray() && obj.Get("data").As().TypedArrayType() == napi_uint8_array), + options.Env(), + "Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding."); + + auto data = obj.Get("data"); + if (data.IsBuffer()) { + buffs.push_back(data.As>().Data()); + sizes.push_back(data.As>().Length()); + } else { + auto typedArray = data.As(); + buffs.push_back(reinterpret_cast(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset()); + sizes.push_back(typedArray.ByteLength()); + } + } + sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes); + } +} + +void ParsePreferredOutputLocations(const Napi::Object options, const std::vector& outputNames, std::vector& preferredOutputLocations) { + if (options.Has("preferredOutputLocation")) { + auto polValue = options.Get("preferredOutputLocation"); + if (polValue.IsNull() || polValue.IsUndefined()) { + return; + } + if (polValue.IsString()) { + DataLocation location = ParseDataLocation(polValue.As().Utf8Value()); + ORT_NAPI_THROW_TYPEERROR_IF(location == DATA_LOCATION_NONE, options.Env(), + "Invalid argument: preferredOutputLocation must be an array or a valid string."); + + if (location == DATA_LOCATION_GPU_BUFFER || location == DATA_LOCATION_ML_TENSOR) { + preferredOutputLocations.resize(outputNames.size(), location); + } + } else if (polValue.IsObject()) { + preferredOutputLocations.resize(outputNames.size(), DATA_LOCATION_CPU); + + auto pol = polValue.As(); + for (const auto& nameIter : pol.GetPropertyNames()) { + Napi::Value nameVar = nameIter.second; + std::string name = nameVar.As().Utf8Value(); + // find the name in outputNames + auto it = std::find(outputNames.begin(), outputNames.end(), name); + ORT_NAPI_THROW_TYPEERROR_IF(it == outputNames.end(), options.Env(), + "Invalid argument: \"", name, "\" is not a valid output name."); + + Napi::Value value = pol.Get(nameVar); + DataLocation location = 
DATA_LOCATION_NONE; + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString() || (location = ParseDataLocation(value.As().Utf8Value())) == DATA_LOCATION_NONE, + options.Env(), + "Invalid argument: preferredOutputLocation[\"", name, "\"] must be a valid string."); + + size_t index = it - outputNames.begin(); + preferredOutputLocations[index] = location; + } + + if (std::all_of(preferredOutputLocations.begin(), preferredOutputLocations.end(), [](int loc) { return loc == DATA_LOCATION_CPU; })) { + preferredOutputLocations.clear(); + } + } else { + ORT_NAPI_THROW_TYPEERROR(options.Env(), "Invalid argument: preferredOutputLocation must be an array or a valid string."); + } + } } diff --git a/js/node/src/session_options_helper.h b/js/node/src/session_options_helper.h index c0a9ae0d683e9..c6338f28e03c6 100644 --- a/js/node/src/session_options_helper.h +++ b/js/node/src/session_options_helper.h @@ -11,3 +11,6 @@ struct SessionOptions; // parse a Javascript session options object and fill the native SessionOptions object. void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions); + +// parse a Javascript session options object and prepare the preferred output locations. +void ParsePreferredOutputLocations(const Napi::Object options, const std::vector& outputNames, std::vector& preferredOutputLocations); \ No newline at end of file diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc index 54f1c5a09906e..27eb9b65c62d3 100644 --- a/js/node/src/tensor_helper.cc +++ b/js/node/src/tensor_helper.cc @@ -8,6 +8,7 @@ #include "common.h" #include "tensor_helper.h" +#include "inference_session_wrap.h" // make sure consistent with origin definition static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == 0, "definition not consistent with OnnxRuntime"); @@ -100,7 +101,7 @@ const std::unordered_map DATA_TYPE_NAME_ {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"string", ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, {"bool", ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, {"float16", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16}, {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}}; // currently only support tensor -Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info) { +Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info) { ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), env, "Tensor must be an object."); // check 'dims' @@ -110,6 +111,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* auto dimsArray = dimsValue.As(); auto len = dimsArray.Length(); + size_t elementSize = 1; std::vector dims; if (len > 0) { dims.reserve(len); @@ -122,17 +124,26 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* "Tensor.dims[", i, "] is invalid: ", dimDouble); int64_t dim = static_cast(dimDouble); dims.push_back(dim); + elementSize *= dim; } } + // check 'location' + auto tensorLocationValue = tensorObject.Get("location"); + ORT_NAPI_THROW_TYPEERROR_IF(!tensorLocationValue.IsString(), env, "Tensor.location must be a string."); + DataLocation 
tensorLocation = ParseDataLocation(tensorLocationValue.As().Utf8Value()); + ORT_NAPI_THROW_RANGEERROR_IF(tensorLocation == DATA_LOCATION_NONE, env, "Tensor.location is not supported."); + // check 'data' and 'type' - auto tensorDataValue = tensorObject.Get("data"); auto tensorTypeValue = tensorObject.Get("type"); ORT_NAPI_THROW_TYPEERROR_IF(!tensorTypeValue.IsString(), env, "Tensor.type must be a string."); auto tensorTypeString = tensorTypeValue.As().Utf8Value(); if (tensorTypeString == "string") { + auto tensorDataValue = tensorObject.Get("data"); + + ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_CPU, env, "Tensor.location must be 'cpu' for string tensors."); ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsArray(), env, "Tensor.data must be an array for string tensors."); auto tensorDataArray = tensorDataValue.As(); @@ -162,29 +173,42 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* auto v = DATA_TYPE_NAME_TO_ID_MAP.find(tensorTypeString); ORT_NAPI_THROW_TYPEERROR_IF(v == DATA_TYPE_NAME_TO_ID_MAP.end(), env, "Tensor.type is not supported: ", tensorTypeString); - ONNXTensorElementDataType elemType = v->second; - ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env, - "Tensor.data must be a typed array for numeric tensor."); + if (tensorLocation == DATA_LOCATION_CPU) { + auto tensorDataValue = tensorObject.Get("data"); + ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env, + "Tensor.data must be a typed array for numeric tensor."); + + auto tensorDataTypedArray = tensorDataValue.As(); + auto typedArrayType = tensorDataValue.As().TypedArrayType(); + ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env, + "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", + tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); - auto tensorDataTypedArray = tensorDataValue.As(); - auto typedArrayType = tensorDataValue.As().TypedArrayType(); - ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env, - "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", - tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); + char* buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); + size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); + size_t bufferByteLength = tensorDataTypedArray.ByteLength(); + return Ort::Value::CreateTensor(cpu_memory_info, buffer + bufferByteOffset, bufferByteLength, + dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + } else { + ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_GPU_BUFFER, env, "Tensor.location must be 'gpu-buffer' for IO binding."); - char* buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); - size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); - size_t bufferByteLength = tensorDataTypedArray.ByteLength(); - return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength, - dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + auto gpuBufferValue = tensorObject.Get("gpuBuffer"); + // nodejs: tensor.gpuBuffer is no longer a GPUBuffer in nodejs. we assume it is an external object (bind the OrtValue pointer). 
+ ORT_NAPI_THROW_TYPEERROR_IF(!gpuBufferValue.IsExternal(), env, "Tensor.gpuBuffer must be an external object."); + Ort::Value dataValue(gpuBufferValue.As>().Data()); + void* gpuBuffer = dataValue.GetTensorMutableRawData(); + dataValue.release(); + + size_t dataByteLength = DATA_TYPE_ELEMENT_SIZE_MAP[elemType] * elementSize; + return Ort::Value::CreateTensor(webgpu_memory_info, gpuBuffer, dataByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + } } } -Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { +Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value) { Napi::EscapableHandleScope scope(env); - auto returnValue = Napi::Object::New(env); auto typeInfo = value.GetTypeInfo(); auto onnxType = typeInfo.GetONNXType(); @@ -197,24 +221,26 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { // type auto typeCstr = DATA_TYPE_ID_TO_NAME_MAP[elemType]; ORT_NAPI_THROW_ERROR_IF(typeCstr == nullptr, env, "Tensor type (", elemType, ") is not supported."); - - returnValue.Set("type", Napi::String::New(env, typeCstr)); + auto type = Napi::String::New(env, typeCstr); // dims const size_t dimsCount = tensorTypeAndShapeInfo.GetDimensionsCount(); - std::vector dims; + std::vector dimsVector; if (dimsCount > 0) { - dims = tensorTypeAndShapeInfo.GetShape(); + dimsVector = tensorTypeAndShapeInfo.GetShape(); } - auto dimsArray = Napi::Array::New(env, dimsCount); + auto dims = Napi::Array::New(env, dimsCount); for (uint32_t i = 0; i < dimsCount; i++) { - dimsArray[i] = dims[i]; + dims[i] = dimsVector[i]; } - returnValue.Set("dims", dimsArray); + + // location + auto memoryInfo = value.GetTensorMemoryInfo(); + bool isGpuBuffer = memoryInfo.GetDeviceType() == OrtMemoryInfoDeviceType_GPU && + memoryInfo.GetAllocatorName() == "WebGPU_Buffer"; // size auto size = tensorTypeAndShapeInfo.GetElementCount(); - returnValue.Set("size", Napi::Number::From(env, size)); // data if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { @@ -234,20 +260,48 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { i == size - 1 ? 
tempBufferLength - tempOffsets[i] : tempOffsets[i + 1] - tempOffsets[i]); } } - returnValue.Set("data", Napi::Value(env, stringArray)); + + // new Tensor("string", stringArray /* string[] */, dims /* number[] */) + return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New({Napi::String::New(env, "string"), stringArray, dims})); } else { // number data - // TODO: optimize memory - auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); - if (size > 0) { - memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + if (isGpuBuffer) { + // Tensor.fromGpuBuffer(buffer, options) + Napi::Function tensorFromGpuBuffer = InferenceSessionWrap::GetTensorConstructor().Value().Get("fromGpuBuffer").As(); + OrtValue* underlyingOrtValue = value.release(); + + auto options = Napi::Object::New(env); + options.Set("dataType", type); + options.Set("dims", dims); + options.Set("dispose", Napi::Function::New( + env, [](const Napi::CallbackInfo& info) { + Ort::GetApi().ReleaseValue(reinterpret_cast(info.Data())); + return info.Env().Undefined(); + }, + "dispose", underlyingOrtValue)); + options.Set("download", Napi::Function::New( + env, [](const Napi::CallbackInfo& info) { + NAPI_THROW("not implemented"); + }, + "download", underlyingOrtValue)); + + return scope.Escape(tensorFromGpuBuffer.Call({Napi::External::New(env, underlyingOrtValue), options})); + } else { + // TODO: optimize memory + auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + if (size > 0) { + memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + } + napi_value typedArrayData; + napi_status status = + napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData); + NAPI_THROW_IF_FAILED(env, status, Napi::Value); + + // new Tensor(type, typedArrayData, dims) + return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New( + {type, + Napi::Value(env, typedArrayData), + dims})); } - napi_value typedArrayData; - napi_status status = - napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData); - NAPI_THROW_IF_FAILED(env, status, Napi::Value); - returnValue.Set("data", Napi::Value(env, typedArrayData)); } - - return scope.Escape(returnValue); } diff --git a/js/node/src/tensor_helper.h b/js/node/src/tensor_helper.h index 56b399ccc24ee..4a51e5240602a 100644 --- a/js/node/src/tensor_helper.h +++ b/js/node/src/tensor_helper.h @@ -9,7 +9,32 @@ #include "onnxruntime_cxx_api.h" // convert a Javascript OnnxValue object to an OrtValue object -Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info); +Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info); // convert an OrtValue object to a Javascript OnnxValue object -Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value); +Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value); + +enum DataLocation { + DATA_LOCATION_NONE = 0, + DATA_LOCATION_CPU = 1, + DATA_LOCATION_CPU_PINNED = 2, + DATA_LOCATION_TEXTURE = 3, + DATA_LOCATION_GPU_BUFFER = 4, + DATA_LOCATION_ML_TENSOR = 5 +}; + +inline DataLocation ParseDataLocation(const std::string& location) { + if (location == "cpu") { + return DATA_LOCATION_CPU; + } else if (location == "cpu-pinned") { + return DATA_LOCATION_CPU_PINNED; + } else if (location == "texture") 
{ + return DATA_LOCATION_TEXTURE; + } else if (location == "gpu-buffer") { + return DATA_LOCATION_GPU_BUFFER; + } else if (location == "ml-tensor") { + return DATA_LOCATION_ML_TENSOR; + } else { + return DATA_LOCATION_NONE; + } +} diff --git a/js/node/tsconfig.json b/js/node/tsconfig.json index c154c3e148ed0..0401fb9609ad6 100644 --- a/js/node/tsconfig.json +++ b/js/node/tsconfig.json @@ -1,7 +1,8 @@ { "extends": "../tsconfig.json", "compilerOptions": { - "outDir": "dist" + "outDir": "dist", + "declaration": true }, "include": ["lib"] } diff --git a/js/package-lock.json b/js/package-lock.json index 58a13a9112116..f4401c6e98c75 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -7,6 +7,7 @@ "license": "MIT", "devDependencies": { "@types/fs-extra": "^11.0.2", + "@types/global-agent": "^2.1.3", "@types/mocha": "^10.0.2", "@types/node": "^18.14.6", "@types/npmlog": "^4.1.4", @@ -23,6 +24,7 @@ "eslint-plugin-prefer-arrow": "^1.2.3", "eslint-plugin-unicorn": "^48.0.1", "fs-extra": "^11.1.1", + "global-agent": "^3.0", "jszip": "^3.10.1", "mocha": "^10.2.0", "npmlog": "^7.0.1", @@ -710,6 +712,13 @@ "@types/node": "*" } }, + "node_modules/@types/global-agent": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@types/global-agent/-/global-agent-2.1.3.tgz", + "integrity": "sha512-rGtZZcgZcKWuKNTkGBGsqyOQ7Nn2MjXh4+xeZbf+5b5KMUx8H1rTqLRackxos7pUlreszbYjQcop5JvqCnZlLw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/json-schema": { "version": "7.0.15", "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", @@ -1289,6 +1298,14 @@ "node": ">=8" } }, + "node_modules/boolean": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.2.0.tgz", + "integrity": "sha512-d0II/GO9uf9lfUHH2BQsjxzRJZBdsjgsBiW4BvhWk/3qoKwQFjIDVN19PfX8F2D/r9PCMTtLWjYVCFrpeYUzsw==", + "deprecated": "Package no longer supported. 
Contact Support at https://www.npmjs.com/support for more info.", + "dev": true, + "license": "MIT" + }, "node_modules/brace-expansion": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", @@ -1556,9 +1573,9 @@ "dev": true }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dev": true, "dependencies": { "path-key": "^3.1.0", @@ -1640,6 +1657,13 @@ "integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==", "dev": true }, + "node_modules/detect-node": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz", + "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==", + "dev": true, + "license": "MIT" + }, "node_modules/diff": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/diff/-/diff-5.0.0.tgz", @@ -1791,6 +1815,13 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/es6-error": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", + "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==", + "dev": true, + "license": "MIT" + }, "node_modules/esbuild": { "version": "0.19.3", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.3.tgz", @@ -2504,6 +2535,24 @@ "node": ">=10.13.0" } }, + "node_modules/global-agent": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", + "integrity": "sha512-PT6XReJ+D07JvGoxQMkT6qji/jVNfX/h364XHZOWeRzy64sSFr+xJ5OX7LI3b4MPQzdL4H8Y8M0xzPpsVMwA8Q==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "boolean": "^3.0.1", + "es6-error": "^4.1.1", + "matcher": "^3.0.0", + "roarr": "^2.15.3", + "semver": "^7.3.2", + "serialize-error": "^7.0.1" + }, + "engines": { + "node": ">=10.0" + } + }, "node_modules/globals": { "version": "13.24.0", "resolved": "https://registry.npmjs.org/globals/-/globals-13.24.0.tgz", @@ -3153,6 +3202,13 @@ "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", "dev": true }, + "node_modules/json-stringify-safe": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz", + "integrity": "sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==", + "dev": true, + "license": "ISC" + }, "node_modules/json5": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/json5/-/json5-1.0.2.tgz", @@ -3272,6 +3328,19 @@ "node": ">=10" } }, + "node_modules/matcher": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz", + "integrity": "sha512-OkeDaAZ/bQCxeFAozM55PKcKU0yJMPGifLwV4Qgjitu+5MoAfSQN4lsLJeXZ1b8w0x+/Emda6MZgXS1jvsapng==", + "dev": true, + "license": "MIT", + "dependencies": { + "escape-string-regexp": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": 
"https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -4075,6 +4144,24 @@ "url": "https://github.com/sponsors/isaacs" } }, + "node_modules/roarr": { + "version": "2.15.4", + "resolved": "https://registry.npmjs.org/roarr/-/roarr-2.15.4.tgz", + "integrity": "sha512-CHhPh+UNHD2GTXNYhPWLnU8ONHdI+5DI+4EYIAOaiD63rHeYlZvyh8P+in5999TTSFgUYuKUAjzRI4mdh/p+2A==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "boolean": "^3.0.1", + "detect-node": "^2.0.4", + "globalthis": "^1.0.1", + "json-stringify-safe": "^5.0.1", + "semver-compare": "^1.0.0", + "sprintf-js": "^1.1.2" + }, + "engines": { + "node": ">=8.0" + } + }, "node_modules/run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", @@ -4157,6 +4244,42 @@ "node": ">=10" } }, + "node_modules/semver-compare": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/semver-compare/-/semver-compare-1.0.0.tgz", + "integrity": "sha512-YM3/ITh2MJ5MtzaM429anh+x2jiLVjqILF4m4oyQB18W7Ggea7BfqdH/wGMK7dDiMghv/6WG7znWMwUDzJiXow==", + "dev": true, + "license": "MIT" + }, + "node_modules/serialize-error": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-7.0.1.tgz", + "integrity": "sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==", + "dev": true, + "license": "MIT", + "dependencies": { + "type-fest": "^0.13.1" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/serialize-error/node_modules/type-fest": { + "version": "0.13.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz", + "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==", + "dev": true, + "license": "(MIT OR CC0-1.0)", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/set-blocking": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz", @@ -4284,6 +4407,13 @@ "integrity": "sha512-rr+VVSXtRhO4OHbXUiAF7xW3Bo9DuuF6C5jH+q/x15j2jniycgKbxU09Hr0WqlSLUs4i4ltHGXqTe7VHclYWyA==", "dev": true }, + "node_modules/sprintf-js": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz", + "integrity": "sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==", + "dev": true, + "license": "BSD-3-Clause" + }, "node_modules/string_decoder": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", @@ -5198,6 +5328,12 @@ "@types/node": "*" } }, + "@types/global-agent": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@types/global-agent/-/global-agent-2.1.3.tgz", + "integrity": "sha512-rGtZZcgZcKWuKNTkGBGsqyOQ7Nn2MjXh4+xeZbf+5b5KMUx8H1rTqLRackxos7pUlreszbYjQcop5JvqCnZlLw==", + "dev": true + }, "@types/json-schema": { "version": "7.0.15", "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", @@ -5588,6 +5724,12 @@ "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", "dev": true }, + "boolean": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.2.0.tgz", + "integrity": "sha512-d0II/GO9uf9lfUHH2BQsjxzRJZBdsjgsBiW4BvhWk/3qoKwQFjIDVN19PfX8F2D/r9PCMTtLWjYVCFrpeYUzsw==", + "dev": true + }, 
"brace-expansion": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", @@ -5780,9 +5922,9 @@ "dev": true }, "cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dev": true, "requires": { "path-key": "^3.1.0", @@ -5838,6 +5980,12 @@ "integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==", "dev": true }, + "detect-node": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz", + "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==", + "dev": true + }, "diff": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/diff/-/diff-5.0.0.tgz", @@ -5965,6 +6113,12 @@ "is-symbol": "^1.0.2" } }, + "es6-error": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", + "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==", + "dev": true + }, "esbuild": { "version": "0.19.3", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.3.tgz", @@ -6511,6 +6665,20 @@ "is-glob": "^4.0.3" } }, + "global-agent": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", + "integrity": "sha512-PT6XReJ+D07JvGoxQMkT6qji/jVNfX/h364XHZOWeRzy64sSFr+xJ5OX7LI3b4MPQzdL4H8Y8M0xzPpsVMwA8Q==", + "dev": true, + "requires": { + "boolean": "^3.0.1", + "es6-error": "^4.1.1", + "matcher": "^3.0.0", + "roarr": "^2.15.3", + "semver": "^7.3.2", + "serialize-error": "^7.0.1" + } + }, "globals": { "version": "13.24.0", "resolved": "https://registry.npmjs.org/globals/-/globals-13.24.0.tgz", @@ -6956,6 +7124,12 @@ "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", "dev": true }, + "json-stringify-safe": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz", + "integrity": "sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==", + "dev": true + }, "json5": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/json5/-/json5-1.0.2.tgz", @@ -7052,6 +7226,15 @@ "yallist": "^4.0.0" } }, + "matcher": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz", + "integrity": "sha512-OkeDaAZ/bQCxeFAozM55PKcKU0yJMPGifLwV4Qgjitu+5MoAfSQN4lsLJeXZ1b8w0x+/Emda6MZgXS1jvsapng==", + "dev": true, + "requires": { + "escape-string-regexp": "^4.0.0" + } + }, "merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -7636,6 +7819,20 @@ "glob": "^7.1.3" } }, + "roarr": { + "version": "2.15.4", + "resolved": "https://registry.npmjs.org/roarr/-/roarr-2.15.4.tgz", + "integrity": "sha512-CHhPh+UNHD2GTXNYhPWLnU8ONHdI+5DI+4EYIAOaiD63rHeYlZvyh8P+in5999TTSFgUYuKUAjzRI4mdh/p+2A==", + "dev": true, + "requires": { + "boolean": "^3.0.1", + "detect-node": "^2.0.4", + "globalthis": "^1.0.1", + "json-stringify-safe": "^5.0.1", + "semver-compare": "^1.0.0", + 
"sprintf-js": "^1.1.2" + } + }, "run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", @@ -7691,6 +7888,29 @@ "lru-cache": "^6.0.0" } }, + "semver-compare": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/semver-compare/-/semver-compare-1.0.0.tgz", + "integrity": "sha512-YM3/ITh2MJ5MtzaM429anh+x2jiLVjqILF4m4oyQB18W7Ggea7BfqdH/wGMK7dDiMghv/6WG7znWMwUDzJiXow==", + "dev": true + }, + "serialize-error": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-7.0.1.tgz", + "integrity": "sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==", + "dev": true, + "requires": { + "type-fest": "^0.13.1" + }, + "dependencies": { + "type-fest": { + "version": "0.13.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz", + "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==", + "dev": true + } + } + }, "set-blocking": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz", @@ -7800,6 +8020,12 @@ "integrity": "sha512-rr+VVSXtRhO4OHbXUiAF7xW3Bo9DuuF6C5jH+q/x15j2jniycgKbxU09Hr0WqlSLUs4i4ltHGXqTe7VHclYWyA==", "dev": true }, + "sprintf-js": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz", + "integrity": "sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==", + "dev": true + }, "string_decoder": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", diff --git a/js/package.json b/js/package.json index a3bd18adce98e..7385ed31eb075 100644 --- a/js/package.json +++ b/js/package.json @@ -1,6 +1,7 @@ { "devDependencies": { "@types/fs-extra": "^11.0.2", + "@types/global-agent": "^2.1.3", "@types/mocha": "^10.0.2", "@types/node": "^18.14.6", "@types/npmlog": "^4.1.4", @@ -17,6 +18,7 @@ "eslint-plugin-prefer-arrow": "^1.2.3", "eslint-plugin-unicorn": "^48.0.1", "fs-extra": "^11.1.1", + "global-agent": "^3.0", "jszip": "^3.10.1", "mocha": "^10.2.0", "npmlog": "^7.0.1", diff --git a/js/react_native/android/build.gradle b/js/react_native/android/build.gradle index 825990eba0fb8..521866ff0f3e2 100644 --- a/js/react_native/android/build.gradle +++ b/js/react_native/android/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { - classpath 'com.android.tools.build:gradle:4.1.2' + classpath 'com.android.tools.build:gradle:7.4.2' // noinspection DifferentKotlinGradleVersion } } @@ -221,9 +221,8 @@ dependencies { api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION api "org.mockito:mockito-core:2.28.2" - androidTestImplementation "androidx.test:runner:1.1.0" - androidTestImplementation "androidx.test:rules:1.1.0" - + androidTestImplementation "androidx.test:runner:1.5.2" + androidTestImplementation "androidx.test:rules:1.5.0" implementation "junit:junit:4.12" androidTestImplementation "com.linkedin.dexmaker:dexmaker-mockito-inline-extended:2.28.1" diff --git a/js/react_native/android/gradle.properties b/js/react_native/android/gradle.properties index 465b04d1f5813..8fe6e40d76911 100644 --- a/js/react_native/android/gradle.properties +++ b/js/react_native/android/gradle.properties @@ -4,7 +4,7 @@ # Specifies the JVM arguments used for the daemon process. # The setting is particularly useful for tweaking memory settings. 
# Default value: -Xmx1024m -XX:MaxPermSize=256m -# org.gradle.jvmargs=-Xmx2048m -XX:MaxPermSize=512m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 +org.gradle.jvmargs=-Xmx4096m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 # # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar index 62d4c053550b9..e6441136f3d4b 100644 Binary files a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar and b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties index 51d930a381f3a..381baa9cef1ec 100644 --- a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties +++ b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,8 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=7faa7198769f872826c8ef4f1450f839ec27f0b4d5d1e51bade63667cbccd205 -distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip +distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d +distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip +networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/js/react_native/android/gradlew b/js/react_native/android/gradlew index fbd7c515832da..1aa94a4269074 100755 --- a/js/react_native/android/gradlew +++ b/js/react_native/android/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,99 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. 
+# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +119,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,88 +130,120 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. 
-if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=`expr $i + 1` + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. 
+ shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/js/react_native/android/gradlew.bat b/js/react_native/android/gradlew.bat index 5093609d512a9..25da30dbdeee9 100644 --- a/js/react_native/android/gradlew.bat +++ b/js/react_native/android/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,8 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. 
+@rem This is normally unused set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,13 +41,13 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init +if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -54,31 +55,16 @@ goto fail set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe -if exist "%JAVA_EXE%" goto init +if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - :execute @rem Setup the command line @@ -86,17 +72,19 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! 
-if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/js/react_native/e2e/android/app/build.gradle b/js/react_native/e2e/android/app/build.gradle index 8a84b0d5065a8..526259e3f8d8f 100644 --- a/js/react_native/e2e/android/app/build.gradle +++ b/js/react_native/e2e/android/app/build.gradle @@ -193,7 +193,7 @@ dependencies { implementation "com.facebook.react:react-native:+" // From node_modules implementation "androidx.swiperefreshlayout:swiperefreshlayout:1.0.0" - implementation 'androidx.test.ext:junit:1.1.3' + implementation 'androidx.test.ext:junit:1.1.5' debugImplementation("com.facebook.flipper:flipper:${FLIPPER_VERSION}") { exclude group:'com.facebook.fbjni' } @@ -213,9 +213,9 @@ dependencies { implementation jscFlavor } - androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' - androidTestImplementation 'androidx.test:runner:1.4.0' - androidTestImplementation 'androidx.test:rules:1.4.0' + androidTestImplementation "androidx.test.espresso:espresso-core:3.5.0" + androidTestImplementation "androidx.test:runner:1.5.2" + androidTestImplementation "androidx.test:rules:1.5.0" implementation project(':onnxruntime-react-native') // specify ORT dependency here so it can be found in libs flatDir repository diff --git a/js/react_native/ios/OnnxruntimeModule.mm b/js/react_native/ios/OnnxruntimeModule.mm index 9da76034fc1ad..16e64d8ed98b4 100644 --- a/js/react_native/ios/OnnxruntimeModule.mm +++ b/js/react_native/ios/OnnxruntimeModule.mm @@ -389,6 +389,8 @@ - (NSDictionary*)run:(NSString*)url if (useOptions) { if ([[executionProvider objectForKey:@"useCPUOnly"] boolValue]) { coreml_flags |= COREML_FLAG_USE_CPU_ONLY; + } else if ([[executionProvider objectForKey:@"useCPUAndGPU"] boolValue]) { + coreml_flags |= COREML_FLAG_USE_CPU_AND_GPU; } if ([[executionProvider objectForKey:@"enableOnSubgraph"] boolValue]) { coreml_flags |= COREML_FLAG_ENABLE_ON_SUBGRAPH; diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 450ae2d06e638..475dfe0d4888b 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.20.0'; +export const version = '1.21.0'; diff --git a/js/react_native/package.json b/js/react_native/package.json index 20b5d02ff233e..ff798530f59d3 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -36,7 +36,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.20.0", + "version": "1.21.0", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 99c03d2e7bf02..fd424f1f76089 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" "onnxruntime-common@file:../common": - version "1.20.0" + version "1.21.0" open@^6.2.0: version "6.4.0" diff --git a/js/scripts/prepare-onnx-node-tests.ts b/js/scripts/prepare-onnx-node-tests.ts index 91aa63e9e6aff..02c33892d57d5 100644 --- a/js/scripts/prepare-onnx-node-tests.ts +++ b/js/scripts/prepare-onnx-node-tests.ts @@ -10,6 +10,8 @@ import * as path from 'path'; import { downloadZip, extractFile } from './utils'; const TEST_DATA_OPSET_VERSIONS = [ + ['opset21', '1.16.2'], + ['opset20', '1.15.0'], ['opset19', '1.14.0'], ['opset18', '1.13.1'], ['opset17', '1.12.1'], diff --git a/js/scripts/utils.ts b/js/scripts/utils.ts index e22eeb1bd9217..5d032dc01957c 100644 --- a/js/scripts/utils.ts +++ b/js/scripts/utils.ts @@ -2,9 +2,15 @@ // Licensed under the MIT License. import { WriteStream } from 'fs'; +import { bootstrap as globalAgentBootstrap } from 'global-agent'; import * as https from 'https'; import { JSZipObject } from 'jszip'; +// Bootstrap global-agent to honor the proxy settings in +// environment variables, e.g. GLOBAL_AGENT_HTTPS_PROXY. +// See https://github.com/gajus/global-agent/blob/v3.0.0/README.md#environment-variables for details. 
+globalAgentBootstrap(); + export const downloadZip = async (url: string): Promise => new Promise((resolve, reject) => { https.get(url, (res) => { diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 1c140de448430..5c8748d75c2bc 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -21,11 +21,11 @@ Do not modify directly.* | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | | Attention | com.microsoft(1+) | need implementing mask and past/present | -| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | +| AveragePool | ai.onnx(7-9,10,11-18,19+); com.ms.internal.nhwc(7-9,10,11-18,19+) | need perf optimization; need implementing activation | | BatchNormalization | ai.onnx(7-8,9-13,14,15+); com.ms.internal.nhwc(7-8,9-13,14,15+) | | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | -| Cast | ai.onnx(6-8,9-12,13-18,19+) | | +| Cast | ai.onnx(6-8,9-12,13-18,19-20,21+) | | | Ceil | ai.onnx(6-12,13+) | | | Clip | ai.onnx(6-10,11,12,13+) | | | Concat | ai.onnx(1-3,4-10,11-12,13+) | | @@ -44,21 +44,23 @@ Do not modify directly.* | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | | FastGelu | com.microsoft(1+) | | -| Flatten | ai.onnx(1-8,9-10,11-12,13+) | | +| Flatten | ai.onnx(1-8,9-10,11-12,13-20,21+) | | | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherBlockQuantized | com.microsoft(1+) | | | GatherElements | ai.onnx(11-12,13+) | | +| GatherND | ai.onnx(11,12,13+) | | | Gelu | ai.onnx(20+); com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| GridSample | ai.onnx(16-19); com.ms.internal.nhwc(16-19) | | | GroupQueryAttention | com.microsoft(1+) | | | HardSigmoid | ai.onnx(6+) | | -| If | ai.onnx(1-10,11-12,13-18,19+) | | +| If | ai.onnx(1-10,11-12,13-18,19-20,21+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(1-16,17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | @@ -74,7 +76,7 @@ Do not modify directly.* | MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | -| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | +| Pad | ai.onnx(2-10,11-12,13-17,18,19-20,21+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | | QuickGelu | com.microsoft(1+) | | | Range | ai.onnx(11+) | | @@ -83,9 +85,9 @@ Do not modify directly.* | ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceLogSum | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceLogSumExp | ai.onnx(1-10,11-12,13-17,18+) | | -| ReduceMax | ai.onnx(1-10,11,12,13-17,18+) | | +| ReduceMax | ai.onnx(1-10,11,12,13-17,18-19,20+) | | | ReduceMean | ai.onnx(1-10,11-12,13-17,18+) | | -| ReduceMin | ai.onnx(1-10,11,12,13-17,18+) | | +| ReduceMin | ai.onnx(1-10,11,12,13-17,18-19,20+) | | | ReduceProd | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceSum | ai.onnx(1-10,11-12,13+) | | | ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | | @@ -93,6 +95,7 @@ Do not modify directly.* | Reshape | ai.onnx(5-12,13,14-18,19-20,21+) | no GPU kernel | | Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | 
RotaryEmbedding | com.microsoft(1+) | | +| ScatterND | ai.onnx(11-12,13-15,16-17,18+) | | | Shape | ai.onnx(1-12,13-14,15-18,19-20,21+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | SimplifiedLayerNormalization | ai.onnx(1+) | | @@ -104,12 +107,12 @@ Do not modify directly.* | Softmax | ai.onnx(1-10,11-12,13+) | | | Split | ai.onnx(1,2-10,11-12,13-17,18+) | | | Sqrt | ai.onnx(6-12,13+) | | -| Squeeze | ai.onnx(1-10,11-12,13+) | | +| Squeeze | ai.onnx(1-10,11-12,13-20,21+) | | | Sub | ai.onnx(7-12,13,14+) | | | Tan | ai.onnx(7+) | | | Tanh | ai.onnx(6-12,13+) | | | ThresholdedRelu | ai.onnx(10+) | | | Tile | ai.onnx(6-12,13+) | | -| Transpose | ai.onnx(1-12,13+) | need perf optimization | -| Unsqueeze | ai.onnx(1-10,11-12,13+) | | +| Transpose | ai.onnx(1-12,13-20,21+) | need perf optimization | +| Unsqueeze | ai.onnx(1-10,11-12,13-20,21+) | | | Where | ai.onnx(9-15,16+) | | diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index f696264aeead7..af7348dba532f 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -13,6 +13,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim |:------:|:------:|:------:|:-:|:-:|:------| | Abs | ai.onnx(7-12, 13+) | abs | ✓ | ✓ | | | Add | ai.onnx(7-12, 13, 14+) | add | ✓ | ✓ | | +| And | ai.onnx(7+) | logicalAnd | ✗ | ✓ | | | ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | | | ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | | | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 | @@ -24,9 +25,11 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). 
WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | +| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | 'axis' input should be a constant | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | -| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✗ | ✓ | | +| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✓ | ✓ | The shape of x_scale should be a subsample of the shape of input | | Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode | +| Einsum | ai.onnx(12+) | reshape, transpose, matmul, reduceSum, mul, triangular | ✓ | ✓ | | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | ✓ | ✓ | | | Erf | ai.onnx(7-9, 10-12, 13+) | erf | ✓ | ✓ | | @@ -35,6 +38,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | | | Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | | +| GatherElements | ai.onnx(11-12, 13+) | gatherElements | ✗ | ✓ | | +| GatherND | ai.onnx(11, 12, 13+) | gatherND | ✓ | ✓ | Only supports 'batch_dims' == 0 | | Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | | | Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input | | GlobalAveragePool | ai.onnx(7+) | averagePool2d | ✓ | ✓ | Only supports 4-D input | @@ -53,6 +58,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | | | Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | | | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 | +| LRN | ai.onnx(7-12, 13+) | pad, averagePool2d, transpose, add, mul, pow, div | ✓ | ✓ | | | LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 
'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | | | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | | @@ -60,11 +66,12 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Min | ai.onnx(7, 8-11, 12, 13+) | min | ✓ | ✓ | | | Mul | ai.onnx(7-12, 13, 14+) | mul | ✓ | ✓ | | | Neg | ai.onnx(7-12, 13+) | neg | ✓ | ✓ | | -| Not | ai.onnx(7+) | logicalnot | ✓ | ✓ | | +| Not | ai.onnx(7+) | logicalNot | ✓ | ✓ | | +| Or | ai.onnx(7+) | logicalOr | ✗ | ✓ | | | Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad | ✓ | ✓ | modes == 'wrap' is not supported | | Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | | | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) | -| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | ✗ | ✓ | | +| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | ✓ | ✓ | The shape of x_scale should be a subsample of the shape of input | | Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✓ | ✓ | | | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✓ | ✓ | Input 'axes' if present should be a constant | @@ -78,13 +85,17 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✓ | ✓ | Input 'axes' if present should be a constant | | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | -| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes | +| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | +| ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements | ✗ | ✓ | Only supports 'reduction' == 'none' | +| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | +| SimplifiedLayerNormalization | ai.onnx(1+) | pow + reduceMean + add + sqrt + div + mul | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | +| Sign | ai.onnx(9-12, 13+) | sign | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | | Softsign | ai.onnx(7+) | softsign | ✓ | ✓ | | | Sin | ai.onnx(7+) | sin | ✓ | ✓ | | -| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 | +| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice, reverse | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant | | Softmax | ai.onnx(7-10, 11-12, 13+) | softmax | ✓ | ✓ | | | Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split | ✓ | ✓ | Input 'split' if present should be a constant | | Sqrt | ai.onnx(7-12, 13+) | sqrt | ✓ | ✓ | | @@ -97,3 +108,4 
@@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Trilu | ai.onnx(14+) | triangular | ✓ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant | | Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Where | ai.onnx(7-8, 9-15, 16+) | where | ✓ | ✓ | | +| Xor | ai.onnx(7+) | logicalXor | ✗ | ✓ | | diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 450ae2d06e638..475dfe0d4888b 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.20.0'; +export const version = '1.21.0'; diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index bfb74355b0d70..a0010df4643a4 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -13,6 +13,7 @@ import { ProgramManager } from './webgpu/program-manager'; import { AdapterInfo, ComputeContext, + DeviceInfo, GpuArchitecture, GpuData, GpuVendor, @@ -134,6 +135,26 @@ class AdapterInfoImpl implements AdapterInfo { } } +class DeviceInfoImpl implements DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; + + constructor(device: GPUDevice) { + this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName); + this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName); + // Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to + // workaround the IDL type checks. + // TODO: clean this after subgroups feature is settled in IDL. + const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number }; + if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) { + this.subgroupSizeRange = undefined; + } else { + this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize]; + } + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. 
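Reading aid for the DeviceInfoImpl introduced above: a small sketch under the same assumption the change makes, namely that 'subgroups' and its size limits are still experimental and absent from the WebGPU IDL, hence the casts. probeSubgroups is an illustrative helper, not part of the change.

```ts
// Probe a GPUDevice for the experimental 'subgroups' feature and its size limits.
const probeSubgroups = (device: GPUDevice): { supported: boolean; sizeRange?: readonly [number, number] } => {
  const supported = device.features.has('subgroups' as GPUFeatureName);
  // The subgroup size limits are not in the standard IDL yet, so go through a loose type.
  const limits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number };
  if (!supported || !limits.minSubgroupSize || !limits.maxSubgroupSize) {
    return { supported };
  }
  return { supported, sizeRange: [limits.minSubgroupSize, limits.maxSubgroupSize] };
};
```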
@@ -141,6 +162,7 @@ class AdapterInfoImpl implements AdapterInfo { export class WebGpuBackend { adapterInfo: AdapterInfoImpl; device: GPUDevice; + deviceInfo: DeviceInfoImpl; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping */ @@ -243,16 +265,22 @@ export class WebGpuBackend { requiredFeatures, }; - if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) { - requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName); - } else if (adapter.features.has('timestamp-query')) { - requiredFeatures.push('timestamp-query'); + // Try requiring WebGPU features + const requireFeatureIfAvailable = (feature: GPUFeatureName) => + adapter.features.has(feature) && requiredFeatures.push(feature) && true; + // Try chromium-experimental-timestamp-query-inside-passes and fallback to timestamp-query + if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) { + requireFeatureIfAvailable('timestamp-query'); } - if (adapter.features.has('shader-f16')) { - requiredFeatures.push('shader-f16'); + requireFeatureIfAvailable('shader-f16'); + // Try subgroups + if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) { + // If subgroups feature is available, also try subgroups-f16 + requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName); } this.device = await adapter.requestDevice(deviceDescriptor); + this.deviceInfo = new DeviceInfoImpl(this.device); this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo())); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); @@ -902,6 +930,10 @@ export class WebGpuBackend { this.sessionStatus = 'default'; } + onCreateSession(): void { + this.gpuDataManager.onCreateSession(); + } + onReleaseSession(sessionId: number): void { this.unregisterBuffers(sessionId); if (this.capturedCommandList.has(sessionId)) { diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 685f3dc019461..b302354c46eeb 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -25,11 +25,31 @@ const onnxDataTypeToWebnnDataType = new Map([ [DataType.uint32, 'uint32'], [DataType.int64, 'int64'], [DataType.uint64, 'uint64'], + [DataType.int4, 'int4'], + [DataType.uint4, 'uint4'], [DataType.int8, 'int8'], [DataType.uint8, 'uint8'], [DataType.bool, 'uint8'], ]); +type MLContextEntry = { + gpuDevice?: GPUDevice; + options?: MLContextOptions; + mlContext: MLContext; +}; + +const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => { + if (a === b) { + return true; + } + if (a === undefined || b === undefined) { + return false; + } + const aKeys = Object.keys(a).sort() as Array; + const bKeys = Object.keys(b).sort() as Array; + return aKeys.length === bKeys.length && aKeys.every((key, index) => key === bKeys[index] && a[key] === b[key]); +}; + /** * WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track * of the current MLContext being used by the sessions. @@ -47,6 +67,10 @@ export class WebNNBackend { * Maps from MLContext to session ids. */ private sessionIdsByMLContext = new Map>(); + /** + * Cache of MLContexts. + */ + private mlContextCache: MLContextEntry[] = []; /** * Current session id. 
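The requireFeatureIfAvailable helper above replaces the old if/else feature checks with a single "request it only if the adapter has it" pattern. A hedged sketch of that pattern in isolation; requestDeviceWithOptionalFeatures is an illustrative wrapper, and the experimental feature names are cast because they are not yet in the GPUFeatureName union.

```ts
const requestDeviceWithOptionalFeatures = async (adapter: GPUAdapter): Promise<GPUDevice> => {
  const requiredFeatures: GPUFeatureName[] = [];
  const requireFeatureIfAvailable = (feature: GPUFeatureName) =>
    adapter.features.has(feature) && requiredFeatures.push(feature) && true;

  // Prefer the in-pass timestamp query; otherwise fall back to 'timestamp-query'.
  if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) {
    requireFeatureIfAvailable('timestamp-query');
  }
  requireFeatureIfAvailable('shader-f16');
  // 'subgroups-f16' only makes sense when 'subgroups' itself was granted.
  if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) {
    requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName);
  }
  return adapter.requestDevice({ requiredFeatures });
};
```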
*/ @@ -67,6 +91,41 @@ export class WebNNBackend { this.activeSessionId = sessionId; } + public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { + if (optionsOrDevice instanceof GPUDevice) { + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext }); + return mlContext; + } + } else if (optionsOrDevice === undefined) { + const mlContextIndex = this.mlContextCache.findIndex( + (entry) => entry.options === undefined && entry.gpuDevice === undefined, + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(); + this.mlContextCache.push({ mlContext }); + return mlContext; + } + } + + const mlContextIndex = this.mlContextCache.findIndex((entry) => + compareMLContextOptions(entry.options, optionsOrDevice), + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ options: optionsOrDevice, mlContext }); + return mlContext; + } + } + public get currentContext(): MLContext { const mlContext = this.getMLContext(this.currentSessionId); if (!mlContext) { @@ -91,12 +150,16 @@ export class WebNNBackend { // Current session is not a WebNN session. return; } + this.tensorManager.releaseTensorsForSession(sessionId); this.mlContextBySessionId.delete(sessionId); const sessionIds = this.sessionIdsByMLContext.get(mlContext)!; sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); - this.tensorManager.releaseTensorsForContext(mlContext); + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext); + if (mlContextIndex !== -1) { + this.mlContextCache.splice(mlContextIndex, 1); + } } } @@ -163,6 +226,71 @@ export class WebNNBackend { return id; } + // Register a WebNN Constant operand from external data. + public registerMLConstant( + externalFilePath: string, + dataOffset: number, + dataLength: number, + builder: MLGraphBuilder, + desc: MLOperandDescriptor, + mountedFiles: Map | undefined, + ): MLOperand { + // If available, "Module.MountedFiles" is a Map for all preloaded files. 
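The MLContext cache added above treats two option bags as the same context when they have identical keys and primitive values. A condensed sketch of that equality rule; MLContextOptionsLike and sameOptions are illustrative stand-ins for the real MLContextOptions type and compareMLContextOptions helper.

```ts
type MLContextOptionsLike = Record<string, string | number | boolean | undefined>;

const sameOptions = (a?: MLContextOptionsLike, b?: MLContextOptionsLike): boolean => {
  if (a === b) {
    return true; // same reference, or both undefined
  }
  if (a === undefined || b === undefined) {
    return false;
  }
  const aKeys = Object.keys(a).sort();
  const bKeys = Object.keys(b).sort();
  return aKeys.length === bKeys.length && aKeys.every((key, i) => key === bKeys[i] && a[key] === b[key]);
};

// Both calls describe the same context, so a cache keyed by this comparison
// would return the entry created for the first call.
console.log(sameOptions({ deviceType: 'gpu' }, { deviceType: 'gpu' })); // true
console.log(sameOptions({ deviceType: 'gpu' }, { deviceType: 'npu' })); // false
```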
+ if (!mountedFiles) { + throw new Error('External mounted files are not available.'); + } + + let filePath = externalFilePath; + if (externalFilePath.startsWith('./')) { + filePath = externalFilePath.substring(2); + } + const fileData = mountedFiles.get(filePath); + if (!fileData) { + throw new Error(`File with name ${filePath} not found in preloaded files.`); + } + + if (dataOffset + dataLength > fileData.byteLength) { + throw new Error('Out of bounds: data offset and length exceed the external file data size.'); + } + + const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer; + let bufferView: ArrayBufferView; + switch (desc.dataType) { + case 'float32': + bufferView = new Float32Array(buffer); + break; + case 'float16': + bufferView = new Uint16Array(buffer); + break; + case 'int32': + bufferView = new Int32Array(buffer); + break; + case 'uint32': + bufferView = new Uint32Array(buffer); + break; + case 'int64': + bufferView = new BigInt64Array(buffer); + break; + case 'uint64': + bufferView = new BigUint64Array(buffer); + break; + case 'int8': + bufferView = new Int8Array(buffer); + break; + case 'int4': + case 'uint4': + case 'uint8': + bufferView = new Uint8Array(buffer); + break; + default: + throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`); + } + + LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`); + + return builder.constant(desc, bufferView); + } + public flush(): void { // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations. } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 7bce5ff9390e8..48bd3ef2bc36f 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -11,7 +11,13 @@ import { WebGpuBackend } from './backend-webgpu'; import { LOG_DEBUG } from './log'; import { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; -import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; +import { + AdapterInfo, + ComputeContext, + ComputeContextInputsOutputsMapping, + DeviceInfo, + ProgramInfo, +} from './webgpu/types'; import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -70,6 +76,7 @@ class TensorViewImpl implements TensorView { class ComputeContextImpl implements ComputeContext { readonly adapterInfo: AdapterInfo; + readonly deviceInfo: DeviceInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -87,42 +94,32 @@ class ComputeContextImpl implements ComputeContext { contextDataOffset: number, ) { this.adapterInfo = backend.adapterInfo; - const heapU32 = module.HEAPU32; + this.deviceInfo = backend.deviceInfo; // extract context data - let dataIndex = contextDataOffset >>> 2; - this.opKernelContext = heapU32[dataIndex++]; - const inputCount = heapU32[dataIndex++]; - this.outputCount = heapU32[dataIndex++]; - this.customDataOffset = heapU32[dataIndex++]; - this.customDataSize = heapU32[dataIndex++]; + const ptrSize = module.PTR_SIZE; + let dataIndex = contextDataOffset / module.PTR_SIZE; + const type = ptrSize === 4 ? 
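The switch in registerMLConstant above maps an MLOperandDescriptor dataType onto a typed-array view over the raw external-weight bytes. A condensed sketch of that mapping; viewForDataType is an illustrative name, and the float16 case deliberately reinterprets the bytes as Uint16Array since JavaScript has no Float16Array.

```ts
// `bytes` is assumed to be the slice of the mounted file holding the tensor,
// with a byte length that matches the element size of `dataType`.
const viewForDataType = (dataType: string, bytes: Uint8Array): ArrayBufferView => {
  const buffer = bytes.slice().buffer; // copy so the view owns its own buffer
  switch (dataType) {
    case 'float32':
      return new Float32Array(buffer);
    case 'float16':
      return new Uint16Array(buffer); // reinterpret, no Float16Array available
    case 'int32':
      return new Int32Array(buffer);
    case 'uint32':
      return new Uint32Array(buffer);
    case 'int64':
      return new BigInt64Array(buffer);
    case 'uint64':
      return new BigUint64Array(buffer);
    case 'int8':
      return new Int8Array(buffer);
    case 'int4':
    case 'uint4':
    case 'uint8':
      return new Uint8Array(buffer); // packed 4-bit data stays byte-addressed
    default:
      throw new Error(`Unsupported data type: ${dataType}`);
  }
};
```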
'i32' : 'i64'; + this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type)); + const inputCount = Number(module.getValue(ptrSize * dataIndex++, type)); + this.outputCount = Number(module.getValue(ptrSize * dataIndex++, type)); + this.customDataOffset = Number(module.getValue(ptrSize * dataIndex++, '*')); + this.customDataSize = Number(module.getValue(ptrSize * dataIndex++, type)); const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = heapU32[dataIndex++]; - const data = heapU32[dataIndex++]; - const dim = heapU32[dataIndex++]; + const dataType = Number(module.getValue(ptrSize * dataIndex++, type)); + const data = Number(module.getValue(ptrSize * dataIndex++, '*')); + const dim = Number(module.getValue(ptrSize * dataIndex++, type)); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(heapU32[dataIndex++]); + dims.push(Number(module.getValue(ptrSize * dataIndex++, type))); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } this.inputs = inputs; } - getMaxComputeWorkgroupSizes(): [number, number, number] { - return [ - this.backend.device.limits.maxComputeWorkgroupSizeX, - this.backend.device.limits.maxComputeWorkgroupSizeY, - this.backend.device.limits.maxComputeWorkgroupSizeZ, - ]; - } - - getMaxComputeWorkgroupStoragesize(): number { - return this.backend.device.limits.maxComputeWorkgroupStorageSize; - } - compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] { // prepare inputs. inputs should always be valid data. const mappedInputs = @@ -152,11 +149,12 @@ class ComputeContextImpl implements ComputeContext { output(index: number, dims: readonly number[]): number { const stack = this.module.stackSave(); try { - const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */); - let offset = data >> 2; - this.module.HEAPU32[offset++] = dims.length; + const ptrSize = this.module.PTR_SIZE; + const type = ptrSize === 4 ? 
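The constructor above now reads the kernel-context block through module.getValue so the same code path works for wasm32 and wasm64. A standalone sketch of that access pattern; EmscriptenLikeModule is an assumed minimal shape, not an actual export of the change.

```ts
interface EmscriptenLikeModule {
  PTR_SIZE: number;
  getValue(ptr: number, type: string): number | bigint;
}

// Read `count` size_t values starting at `ptr`. On wasm32 a size_t is an 'i32';
// on wasm64 it is an 'i64' and may come back as a bigint, hence Number(...).
const readSizeTArray = (module: EmscriptenLikeModule, ptr: number, count: number): number[] => {
  const ptrSize = module.PTR_SIZE;
  const type = ptrSize === 4 ? 'i32' : 'i64';
  const values: number[] = [];
  for (let i = 0; i < count; i++) {
    values.push(Number(module.getValue(ptr + i * ptrSize, type)));
  }
  return values;
};
```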
'i32' : 'i64'; + const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */); + this.module.setValue(data, dims.length, type); for (let i = 0; i < dims.length; i++) { - this.module.HEAPU32[offset++] = dims[i]; + this.module.setValue(data + ptrSize * (i + 1), dims[i], type); } return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { @@ -215,7 +213,7 @@ export const init = async ( backend, // jsepAlloc() - (size: number) => backend.alloc(size), + (size: number) => backend.alloc(Number(size)), // jsepFree() (ptr: number) => backend.free(ptr), @@ -223,12 +221,19 @@ export const init = async ( // jsepCopy(src, dst, size, isSourceGpu) (src: number, dst: number, size: number, isSourceGpu = false) => { if (isSourceGpu) { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); - backend.memcpy(src, dst); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, + ); + backend.memcpy(Number(src), Number(dst)); } else { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); - const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size); - backend.upload(dst, data); + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, + ); + const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); + backend.upload(Number(dst), data); } }, @@ -239,12 +244,19 @@ export const init = async ( () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, ); - await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); + await backend.download(Number(gpuDataId), () => + module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), + ); }, // jsepCreateKernel (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), + backend.createKernel( + kernelType, + Number(kernelId), + attribute, + module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), + ), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), @@ -256,8 +268,8 @@ export const init = async ( () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, ); - const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, errors); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); + return backend.computeKernel(Number(kernel), context, errors); }, // jsepCaptureBegin () => backend.captureBegin(), diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 5ae16d5625dc8..85aca96057df2 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -167,7 +167,7 @@ export class ShapeUtil { 'cannot get valid size from specified dimension range. 
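The output() rewrite above is the write-side counterpart: the dims array is marshalled into stack memory as size_t values in a [rank, dim0, dim1, ...] layout. A sketch under the same assumptions; StackModule is an illustrative shape, and the caller brackets the call with stackSave()/stackRestore() as the real code does.

```ts
interface StackModule {
  PTR_SIZE: number;
  stackSave(): number;
  stackRestore(stack: number): void;
  stackAlloc(size: number): number;
  setValue(ptr: number, value: number, type: string): void;
}

const writeDims = (module: StackModule, dims: readonly number[]): number => {
  const ptrSize = module.PTR_SIZE;
  const type = ptrSize === 4 ? 'i32' : 'i64';
  const data = module.stackAlloc((1 + dims.length) * ptrSize);
  module.setValue(data, dims.length, type); // rank first
  for (let i = 0; i < dims.length; i++) {
    module.setValue(data + ptrSize * (i + 1), dims[i], type); // then each dim
  }
  return data; // pointer handed to the wasm side before the stack is restored
};
```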
Most likely the range contains negative values in them.', ); } - size *= dims[i]; + size *= Number(dims[i]); } return size; } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 33e8c95c141ee..1c6016500e7d3 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -64,6 +64,11 @@ export interface GpuDataManager { */ dispose(): void; + /** + * create session related data. + */ + onCreateSession(): void; + /** * release session related data. * @param sessionId - specify the session ID. @@ -112,7 +117,7 @@ const bucketArr: number[] = []; /** * normalize the buffer size so that it fits the 128-bits (16 bytes) alignment. */ -const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; +const calcNormalizedBufferSize = (size: number) => Math.ceil(Number(size) / 16) * 16; /** * calculate the buffer size so that it fits into buckets. @@ -186,8 +191,6 @@ class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) private storageCache: Map; - // pending buffers for uploading ( data is unmapped ) - private buffersForUploadingPending: GPUBuffer[]; // pending buffers for computing private buffersPending: GPUBuffer[]; @@ -200,11 +203,13 @@ class GpuDataManagerImpl implements GpuDataManager { // a SessionID -> GPUBuffer[] mapping. private capturedPendingBuffers: Map; + // The session count. + private sessionCount: number; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); - this.buffersForUploadingPending = []; this.buffersPending = []; this.capturedPendingBuffers = new Map(); @@ -213,6 +218,8 @@ class GpuDataManagerImpl implements GpuDataManager { this.freeBuffers.set(key, []); this.freeUniformBuffers.set(key, []); } + + this.sessionCount = 0; } upload(id: GpuDataId, data: Uint8Array): void { @@ -226,7 +233,7 @@ class GpuDataManagerImpl implements GpuDataManager { if (!gpuDataCache) { throw new Error('gpu data for uploading does not exist'); } - if (gpuDataCache.originalSize !== srcLength) { + if (Number(gpuDataCache.originalSize) !== srcLength) { throw new Error(`inconsistent data size. 
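calcNormalizedBufferSize above now coerces possibly-bigint sizes before rounding up to the 16-byte alignment the storage buffers use. For example (the parameter is widened to number | bigint here purely for clarity):

```ts
const calcNormalizedBufferSize = (size: number | bigint): number => Math.ceil(Number(size) / 16) * 16;

console.log(calcNormalizedBufferSize(1)); // 16
console.log(calcNormalizedBufferSize(16)); // 16
console.log(calcNormalizedBufferSize(17)); // 32
console.log(calcNormalizedBufferSize(100n)); // 112, bigint sizes from a wasm64 build are coerced first
```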
gpu data size=${gpuDataCache.originalSize}, data size=${srcLength}`); } @@ -242,13 +249,12 @@ class GpuDataManagerImpl implements GpuDataManager { gpuBufferForUploading.unmap(); // GPU copy - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); + const commandEncoder = this.backend.device.createCommandEncoder(); commandEncoder.copyBufferToBuffer(gpuBufferForUploading, 0, gpuDataCache.gpuData.buffer, 0, size); + this.backend.device.queue.submit([commandEncoder.finish()]); + gpuBufferForUploading.destroy(); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.upload(id=${id})`); - - this.buffersForUploadingPending.push(gpuBufferForUploading); } memcpy(sourceId: GpuDataId, destinationId: GpuDataId): void { @@ -288,9 +294,7 @@ class GpuDataManagerImpl implements GpuDataManager { LOG_DEBUG( 'verbose', () => - `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ - id - }, buffer is the same, skip.`, + `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, buffer is the same, skip.`, ); return id; } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { @@ -347,7 +351,7 @@ class GpuDataManagerImpl implements GpuDataManager { } const gpuData = { id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer }; - this.storageCache.set(gpuData.id, { gpuData, originalSize: size }); + this.storageCache.set(gpuData.id, { gpuData, originalSize: Number(size) }); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`); return gpuData; @@ -357,10 +361,16 @@ class GpuDataManagerImpl implements GpuDataManager { return this.storageCache.get(id)?.gpuData; } - release(id: GpuDataId): number { + release(idInput: GpuDataId): number { + const id = typeof idInput === 'bigint' ? Number(idInput) : idInput; const cachedData = this.storageCache.get(id); if (!cachedData) { - throw new Error('releasing data does not exist'); + if (this.storageCache.size === 0) { + // cache was previously cleared, no need to release anything. + return 0; + } else { + throw new Error('releasing data does not exist'); + } } LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.release(id=${id}), gpuDataId=${cachedData.gpuData.id}`); @@ -373,7 +383,7 @@ class GpuDataManagerImpl implements GpuDataManager { } async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { - const cachedData = this.storageCache.get(id); + const cachedData = this.storageCache.get(Number(id)); if (!cachedData) { throw new Error('data does not exist'); } @@ -381,12 +391,6 @@ class GpuDataManagerImpl implements GpuDataManager { } refreshPendingBuffers(): void { - for (const buffer of this.buffersForUploadingPending) { - // upload buffer is only useful in the session creation time. So we don't need to reuse them in session running. - buffer.destroy(); - } - this.buffersForUploadingPending = []; - if (this.buffersPending.length === 0) { return; } @@ -460,6 +464,10 @@ class GpuDataManagerImpl implements GpuDataManager { this.capturedPendingBuffers = new Map(); } + onCreateSession() { + this.sessionCount += 1; + } + onReleaseSession(sessionId: number) { // release the captured pending buffers. const pendingBuffers = this.capturedPendingBuffers.get(sessionId); @@ -469,6 +477,16 @@ class GpuDataManagerImpl implements GpuDataManager { }); this.capturedPendingBuffers.delete(sessionId); } + + // release the storage cache if no active sessions. 
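The upload path above no longer parks staging buffers in a pending list: the copy is submitted on its own command encoder and the staging buffer is destroyed right away. A self-contained sketch of that flow, assuming a valid GPUDevice and a destination buffer created with COPY_DST usage whose size is at least the padded size.

```ts
const uploadToGpuBuffer = (device: GPUDevice, target: GPUBuffer, data: Uint8Array): void => {
  const size = Math.ceil(data.byteLength / 16) * 16; // keep the 16-byte padding rule
  const staging = device.createBuffer({
    size,
    usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC,
    mappedAtCreation: true,
  });
  new Uint8Array(staging.getMappedRange()).set(data);
  staging.unmap();

  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(staging, 0, target, 0, size);
  device.queue.submit([encoder.finish()]);
  staging.destroy(); // safe: destruction is deferred until the submitted copy completes
};
```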
+ this.sessionCount -= 1; + if (this.sessionCount === 0) { + LOG_DEBUG('warning', () => '[WebGPU] Clearing webgpu buffer cache'); + this.storageCache.forEach((storage) => { + storage.gpuData.buffer.destroy(); + }); + this.storageCache = new Map(); + } } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index fe824a5c4558a..6c7afbc7365bb 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,10 +16,12 @@ import { einsum, parseEinsumAttributes } from './ops/einsum'; import { expand } from './ops/expand'; import { fastGelu } from './ops/fast-gelu'; import { gather, parseGatherAttributes } from './ops/gather'; +import { gatherND, parseGatherNDAttributes } from './ops/gather-nd'; import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized'; import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; import { gemm, parseGemmAttributes } from './ops/gemm'; -import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention'; +import { gridSample, parseGridSampleAttributes } from './ops/grid-sample'; +import { groupQueryAttention } from './ops/group-query-attention'; import { instanceNorm } from './ops/instance-norm'; import { layerNorm } from './ops/layer-norm'; import { matMul } from './ops/matmul'; @@ -29,6 +31,7 @@ import { pad } from './ops/pad'; import * as pool from './ops/pool'; import { dequantizeLinear, parseDequantizeLinearAttributes } from './ops/quantize-linear'; import { range } from './ops/range'; +import { scatterND, parseScatterNDAttributes } from './ops/scatter-nd'; import { reduceL1, reduceL2, @@ -98,13 +101,15 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]], + ['GatherND', [gatherND, parseGatherNDAttributes]], ['Gelu', [unaryOps.gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], - ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]], + ['GridSample', [gridSample, parseGridSampleAttributes]], + ['GroupQueryAttention', [groupQueryAttention]], ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], @@ -138,6 +143,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Relu', [unaryOps.relu]], ['Resize', [resize, parseResizeAttributes]], ['RotaryEmbedding', [rotaryEmbedding]], + ['ScatterND', [scatterND, parseScatterNDAttributes]], ['Sigmoid', [unaryOps.sigmoid]], ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3ef5c943d5624..9e21a552b8466 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -143,7 +143,21 @@ const conv2dCommonSnippet = ( } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`; - const sampleW = 
`${getWSnippet(innerElementSizeW)}`; + const sampleW = isChannelsLast + ? fitInner && fitBOuter + ? getWSnippet(innerElementSizeW) + : ` + let col = colIn * ${innerElementSizeW}; + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { + ${getWSnippet(innerElementSizeW)} + } + return ${typeSnippet(innerElementSizeW, dataType)}(0.0);` + : ` + let col = colIn * ${innerElementSizeW}; + if (row < uniforms.dim_inner && col < uniforms.dim_a_outer) { + ${getWSnippet(innerElementSizeW)} + } + return ${typeSnippet(innerElementSizeW, dataType)}(0.0);`; const resType = typeSnippet(innerElementSize, dataType); const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 2a8756e435b8e..cb1f30ecdd1f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -29,229 +29,27 @@ import { ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType, + getMaxComponents, } from '../common'; import { ConvTransposeAttributes } from '../conv-transpose'; -const createConvTranspose2DOpProgramShaderSource = ( - shaderHelper: ShaderHelper, - inputs: readonly TensorView[], - outputShape: readonly number[], - hasBias: boolean, - is1DimensionDispatch: boolean, - isVec4 = false, - dataType: string, - uniforms: UniformsArrayType, - isChannelsLast = false, -): string => { - const rowDim = isChannelsLast ? 1 : 2; - const colDim = isChannelsLast ? 2 : 3; - const channelDim = isChannelsLast ? 3 : 1; - const workPerThread = isVec4 ? 2 : 1; - - let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { - result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); - }`; - if (hasBias) { - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; - } - const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); - const inputVariables = [dy, w]; - if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - - const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; - let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; - let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - - let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); - - // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). - // ? = to be determined. : = across all values in that axis. 
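The reworked sampleW snippet in conv2d_mm_webgpu.ts earlier in this hunk wraps the weight fetch in a row/col bounds check so out-of-range tile loads return zeros instead of reading past the logical matrix. A scalar host-side analogue, with illustrative names and a row-major layout assumed:

```ts
// dimInner x dimBOuter is the logical weight matrix; anything outside it reads as 0.
const sampleW = (w: Float32Array, dimInner: number, dimBOuter: number, row: number, col: number): number =>
  row < dimInner && col < dimBOuter ? w[row * dimBOuter + col] : 0;

const w = new Float32Array([1, 2, 3, 4]); // 2 x 2
console.log(sampleW(w, 2, 2, 1, 1)); // 4
console.log(sampleW(w, 2, 2, 2, 0)); // 0, out-of-bounds row
```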
- var dotProd: array, ${workPerThread}>; - for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4<${dataType}>(0.0); - } - for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); - let wRPerm = uniforms.filter_dims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || - fract(dyR) > 0.0 || wRPerm < 0) { - continue; - } - let idyR: u32 = u32(dyR); - - for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let wCPerm = uniforms.filter_dims[1] - 1 - wC; - if (wCPerm < 0) { - continue; - } - var bDyCVal = true; - var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || - fract(dyC) > 0.0) { - bDyCVal = false; - } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || - fract(dyC2) > 0.0) { - bDyCVal2 = false; - } - - let idyC: u32 = u32(dyC); - let idyC2: u32 = u32(dyC2); - if (bDyCVal && bDyCVal2) { - let d2Length = uniforms.Dy_shape[3]; - for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[0] = dotProd[0] + tmpval; - - xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - - dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - } - } else if (bDyCVal) { - let d2Length = uniforms.Dy_shape[${channelDim}]; - for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[0] = dotProd[0] + tmpval; - } - } else if (bDyCVal2) { - let d2Length = uniforms.Dy_shape[3]; - for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[1] = dotProd[1] + tmpval; - } - } - } - } - - for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) { - let value = dotProd[i] + ${hasBias ? 
'bias[c+i]' : `vec4<${dataType}>(0.0)`}; - ${output.set('batch', 'r', 'c + i', 'd1', 'value')}; - } - }`; - const codeSnippet = ` - let outputIndices = ${output.offsetToIndices('global_idx')}; - let batch = ${output.indicesGet('outputIndices', 0)}; - let d1 = ${output.indicesGet('outputIndices', channelDim)}; - let r = ${output.indicesGet('outputIndices', rowDim)}; - let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; - let dyRCorner = dyCorner.x; - let dyCCorner = dyCorner.y; - let groupId = d1 / uniforms.output_channels_per_group; - let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; - // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). - // ? = to be determined. : = across all values in that axis. - var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { - if (wR % uniforms.dilations.x != 0) { - continue; - } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); - let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || - wRPerm < 0) { - continue; - } - let idyR: u32 = u32(dyR); - - for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { - if (wC % uniforms.dilations.y != 0) { - continue; - } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || - fract(dyC) > 0.0 || wCPerm < 0) { - continue; - } - let idyC: u32 = u32(dyC); - var inputChannel = groupId * uniforms.input_channels_per_group; - for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { - let xValue = ${ - isChannelsLast - ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') - : dy.get('batch', 'inputChannel', 'idyR', 'idyC') - }; - let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; - dotProd = dotProd + xValue * wValue; - inputChannel = inputChannel + 1; - } - } - } - let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`}; - ${output.setByOffset('global_idx', 'value')}; - `; - - return ` - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} - ${declareFunctions} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; - ${isVec4 ? codeSnippet4 : codeSnippet}}`; -}; - export const createConvTranspose2DProgramInfo = ( inputs: readonly TensorView[], attributes: ConvTransposeAttributes, squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): ProgramInfo => { const hasBias = inputs.length > 2; - // const isChannelsLast = attributes.format === 'NHWC'; const outputShape = attributes.outputShape; - const outputSize = ShapeUtil.size(outputShape); - - // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // TODO Enable isVec4 for performance - // Disabled due to weight matrix layout issue - // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0; + const isChannelsLast = attributes.format === 'NHWC'; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[2] / group; + const outputChannelsPerGroup = wShape[3]; + const components = isChannelsLast ? 
getMaxComponents(outputChannelsPerGroup) : 1; + const outputSize = ShapeUtil.size(outputShape) / components; const dispatch = [Math.ceil(outputSize / 64), 1, 1]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const isChannelsLast = attributes.format === 'NHWC'; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const strides = [attributes.strides[0], attributes.strides[1]]; const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; @@ -268,15 +66,9 @@ export const createConvTranspose2DProgramInfo = ( ]; const pads = [ effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), - effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2, + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2), ]; - const isVec4 = false; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; - const programUniforms: ProgramUniform[] = [ { type: DataType.uint32, data: outputSize }, { type: DataType.uint32, data: strides }, @@ -294,7 +86,6 @@ export const createConvTranspose2DProgramInfo = ( } programUniforms.push(...createTensorShapeVariables(outputShape)); - const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; const getShaderSource = (shaderHelper: ShaderHelper) => { const uniforms: UniformsArrayType = [ { name: 'output_size', type: 'u32' }, @@ -307,21 +98,83 @@ export const createConvTranspose2DProgramInfo = ( { name: 'output_channels_per_group', type: 'u32' }, ]; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - return `${createConvTranspose2DOpProgramShaderSource( - shaderHelper, - inputs, - outputShape, - hasBias, - is1DimensionDispatch, - isVec4, - dataType, - uniforms, - isChannelsLast, - )}`; + const rowDim = isChannelsLast ? 1 : 2; + const colDim = isChannelsLast ? 2 : 3; + const channelDim = isChannelsLast ? 3 : 1; + + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length); + const inputVariables = [dy, w]; + if (hasBias) { + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); + } + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + + const codeSnippet = ` + let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; + let batch = ${output.indicesGet('outputIndices', 0)}; + let d1 = ${output.indicesGet('outputIndices', channelDim)}; + let r = ${output.indicesGet('outputIndices', rowDim)}; + let c = ${output.indicesGet('outputIndices', colDim)}; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; + let dyRCorner = dyCorner.x; + let dyCCorner = dyCorner.y; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; + // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). + // ? = to be determined. : = across all values in that axis. 
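The rewritten ConvTranspose2D program above picks a vector width per output-channel group via getMaxComponents(outputChannelsPerGroup) and divides the dispatch size by it. A sketch that mirrors the behaviour of that helper (the real one lives in ops/common.ts):

```ts
// Use vec4 when the size is divisible by 4, vec2 when divisible by 2, else scalars.
const getMaxComponents = (size: number): 1 | 2 | 4 => {
  if (size % 4 === 0) {
    return 4;
  }
  if (size % 2 === 0) {
    return 2;
  }
  return 1;
};

// With NHWC weights and outputChannelsPerGroup = 64, each invocation writes a
// vec4, so outputSize = ShapeUtil.size(outputShape) / 4 and the dispatch shrinks 4x.
console.log(getMaxComponents(64)); // 4
console.log(getMaxComponents(6)); // 2
console.log(getMaxComponents(7)); // 1
```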
+ var dotProd = ${output.type.value}(0.0); + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { + continue; + } + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || + wRPerm < 0) { + continue; + } + let idyR: u32 = u32(dyR); + + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { + continue; + } + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || + fract(dyC) > 0.0 || wCPerm < 0) { + continue; + } + let idyC: u32 = u32(dyC); + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { + let xValue = ${ + isChannelsLast + ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') + : dy.get('batch', 'inputChannel', 'idyR', 'idyC') + }; + let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; + let wValue = ${w.getByOffset(`w_offset / ${components}`)}; + dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; + } + } + } + let value = dotProd${hasBias ? ` + bias[d1 / ${components}]` : ''}; + ${output.setByOffset('global_idx', 'value')}; + `; + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; + ${codeSnippet}}`; }; + return { name: 'ConvTranspose2D', - shaderCache: { hint: `${attributes.cacheKey};`, inputDependencies }, + shaderCache: { hint: `${attributes.cacheKey};${components}`, inputDependencies }, getRunData: () => ({ dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, outputs: [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index f0287529ca08b..c6341f94cf191 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -25,7 +25,6 @@ import { ShapeUtil } from '../../../util'; import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types'; import { createTensorShapeVariables, - getBroadcastDims, IndicesHelper, inputVariable, internalVariable, @@ -40,6 +39,7 @@ import { getActivationSnippet, InternalActivationAttributes, } from '../fuse-utils'; +import { convertOutputBatchIndicesToInputBatchIndices } from '../matmul-shaders'; import { typeSnippet } from './activation_util'; @@ -373,42 +373,11 @@ const matMulReadWriteFnSource = ( hasBias: boolean, applyActivation: string, variables: IndicesHelper[], - batchShapes: Array, isChannelsLast = false, ): string => { - const [batchAShape, batchBShape, batchShape] = batchShapes; const [batchVariable, aVariable, bVariable, outputVariable] = variables; - const broadCastADims = getBroadcastDims(batchAShape, batchShape); - const broadCastBDims = getBroadcastDims(batchBShape, batchShape); const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); - const getAIndices = () => { - const aRank = 
aVariable.rank; - const batchRank = batchVariable.rank; - let resStr = `var aIndices: ${aVariable.type.indices};`; - for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastADims.forEach((i) => { - resStr += `\naIndices[${i}] = 0;`; - }); - resStr += `\naIndices[${aRank - 2}] = u32(row); - aIndices[${aRank - 1}] = u32(colIn);`; - return resStr; - }; - const getBIndices = () => { - const bRank = bVariable.rank; - const batchRank = batchVariable.rank; - let resStr = `var bIndices: ${bVariable.type.indices};`; - for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastBDims.forEach((i) => { - resStr += `\nbIndices[${i}] = 0;`; - }); - resStr += `\nbIndices[${bRank - 2}] = u32(row); - bIndices[${bRank - 1}] = u32(colIn);`; - return resStr; - }; + const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet( component, @@ -418,7 +387,16 @@ const matMulReadWriteFnSource = ( let col = colIn * ${component}; if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { - ${getAIndices()} + var aIndices: ${aVariable.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices( + 'aIndices', + aVariable, + aVariable.rank - 2, + batchVariable.rank, + 'batchIndices', + )} + ${aVariable.indicesSet('aIndices', aVariable.rank - 2, 'u32(row)')} + ${aVariable.indicesSet('aIndices', aVariable.rank - 1, 'u32(colIn)')} value = ${aVariable.getByIndices('aIndices')}; } return value; @@ -432,7 +410,16 @@ const matMulReadWriteFnSource = ( let col = colIn * ${component}; if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { - ${getBIndices()} + var bIndices: ${bVariable.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices( + 'bIndices', + bVariable, + bVariable.rank - 2, + batchVariable.rank, + 'batchIndices', + )} + ${bVariable.indicesSet('bIndices', bVariable.rank - 2, 'u32(row)')} + ${bVariable.indicesSet('bIndices', bVariable.rank - 1, 'u32(colIn)')} value = ${bVariable.getByIndices('bIndices')}; } return value; @@ -532,7 +519,6 @@ export const createMatmulProgramInfo = ( hasBias, applyActivation, [batchDims, A, B, output], - [outerDimsA, outerDimsB, outerDims], isChannelsLast, ); return ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 832f6e132901e..6a78c8ae3b190 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -8,6 +8,7 @@ import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramU import { getMaxComponents, + IndicesHelper, inputVariable, outputVariable, ShaderHelper, @@ -65,14 +66,17 @@ export interface AttentionParameters { broadcastResPosBias: boolean; passPastInKv: boolean; qkvFormat: AttentionQkvFormat; - isPastkvBSNH?: boolean; + softcap?: number; + doRotary?: number; + rotaryInterLeaved?: number; + sommoothSoftmax?: number; + localWindowsSize?: number; } export interface AttentionAttrs { numHeads: number; - kvNumHeads?: number; - isUnidirectional?: number; - maskFilterValue?: number; + isUnidirectional: number; + maskFilterValue: number; scale: number; doRotary: number; qkvHiddenSizes: number[]; @@ -258,41 +262,106 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte }; }; -const createInPlaceSoftmaxProgramInfo = 
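matMulReadWriteFnSource above now delegates batch broadcasting to convertOutputBatchIndicesToInputBatchIndices (moved into matmul-shaders) instead of precomputing broadcast dims per input. A host-side sketch of the rule it encodes, with an illustrative helper name; the WGSL helper itself is defined elsewhere in the change set.

```ts
// Right-align the input's batch dims against the output's batch dims and clamp
// any broadcast (size-1) dim to index 0.
const toInputBatchIndices = (
  outputBatchIndices: readonly number[],
  inputBatchShape: readonly number[],
): number[] => {
  const rankDiff = outputBatchIndices.length - inputBatchShape.length;
  return inputBatchShape.map((dimSize, i) => (dimSize === 1 ? 0 : outputBatchIndices[i + rankDiff]));
};

// A [2, 1] input batch broadcast against a [2, 3] output batch:
console.log(toInputBatchIndices([1, 2], [2, 1])); // [1, 0]
```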
(input: TensorView, n: number, d: number) => { - const components = getMaxComponents(d); +const initVarStub = ( + seqLensInput: IndicesHelper | undefined, + totalSequenceLengthInput: IndicesHelper | undefined, + initPastSequenceLength: boolean, +) => { + // In the case of GQA, redefine total_sequence_length, present_sequence_length and past_sequence_length based on seqlen_k input + if (totalSequenceLengthInput && seqLensInput) { + return ` + let total_sequence_length_input = u32(${totalSequenceLengthInput.getByOffset('0')}); + let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length); + let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input; + let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input; + total_sequence_length = u32(${seqLensInput?.getByOffset('batchIdx')}) + 1; + var past_sequence_length: u32 = 0; + if (is_first_prompt == false) { + past_sequence_length = total_sequence_length - sequence_length; + } + `; + } else { + return ` + ${initPastSequenceLength ? 'let past_sequence_length = uniforms.past_sequence_length' : ''}; + let present_sequence_length = total_sequence_length; + `; + } +}; + +const createInPlaceSoftmaxProgramInfo = ( + input: TensorView, + batchSize: number, + numHeads: number, + pastSequenceLength: number, + sequenceLength: number, + totalSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, +) => { + // Set components to 1 if seqLens is specified, i.e. GroupQueryAttention. + const components = getMaxComponents(seqLens ? 1 : totalSequenceLength); let WG = 64; - const dComp = d / components; - if (dComp < WG) { + const totalSequenceLengthComp = totalSequenceLength / components; + if (totalSequenceLengthComp < WG) { WG = 32; } - const elementsPerThread = Math.ceil(d / components / WG); + const elementsPerThread = Math.ceil(totalSequenceLength / components / WG); const programUniforms: ProgramUniform[] = [ - { type: DataType.float, data: 1 / d }, - { type: DataType.uint32, data: dComp }, + { type: DataType.uint32, data: batchSize }, + { type: DataType.uint32, data: numHeads }, + { type: DataType.uint32, data: pastSequenceLength }, + { type: DataType.uint32, data: sequenceLength }, + { type: DataType.uint32, data: totalSequenceLengthComp }, { type: DataType.uint32, data: elementsPerThread }, ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const f32Type = tensorTypeToWsglValueType(DataType.float, components); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); + const inputHelpers = [inputHelper]; + const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputHelper) { + inputHelpers.push(seqLensInputHelper); + } + + const totalSequenceLengthInputHelper = totalSequenceLengthInput + ? 
inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInputHelper) { + inputHelpers.push(totalSequenceLengthInputHelper); + } const elemValueType = tensorTypeToWsglValueType(input.dataType); const uniforms: UniformsArrayType = [ - { name: 'd_inv', type: 'f32' }, - { name: 'd_comp', type: 'u32' }, + { name: 'batch_size', type: 'u32' }, + { name: 'num_heads', type: 'u32' }, + { name: 'past_sequence_length', type: 'u32' }, + { name: 'sequence_length', type: 'u32' }, + { name: 'total_sequence_length', type: 'u32' }, { name: 'elements_per_thread', type: 'u32' }, ]; return ` var thread_max: array; var thread_sum: array; - ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)} ${shaderHelper.mainStart([WG, 1, 1])} + let batchIdx = workgroup_id.z / uniforms.num_heads; + let headIdx = workgroup_id.z % uniforms.num_heads; + let sequence_length = uniforms.sequence_length; + var total_sequence_length = uniforms.total_sequence_length; + ${initVarStub(seqLensInputHelper, totalSequenceLengthInputHelper, false)} let local_offset = local_idx * uniforms.elements_per_thread; - let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset; - + let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; + let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; var thread_max_vector = ${f32Type}(-3.402823e+38f); - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } thread_max[local_idx] = ${(() => { @@ -315,7 +384,7 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number } var sum_vector = ${f32Type}(0); - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { sum_vector += exp(${f32Type}(x[offset + i]) - max_value); } thread_sum[local_idx] = ${(() => { @@ -338,15 +407,23 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number } if (sum == 0) { - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { - x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv)); + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { + x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length)); } } else { - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { var f32input = ${f32Type}(x[offset + i]); x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum); } } + ${ + seqLens + ? 
` + for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) { + x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0)); + }` + : '' + }; }`; }; @@ -354,7 +431,11 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number name: 'AttentionProbsSoftmax', shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies }, getShaderSource, - getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }), + getRunData: () => ({ + outputs: [], + dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads }, + programUniforms, + }), }; }; @@ -365,19 +446,21 @@ const createAttentionProbsProgramInfo = ( pastKey: TensorView | undefined, attentionBias: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, pastSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, ) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; - const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey; + const presentKey = outputCount > 1 && pastKey; + const kvNumHeads = parameters.kvNumHeads ? parameters.kvNumHeads : parameters.numHeads; const presentKeyShape = presentKey - ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] + ? [parameters.batchSize, kvNumHeads, totalSequenceLength, parameters.headSize] : undefined; - + const nReps = parameters.nReps ? parameters.nReps : 1; // TODO: handle mask - const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + const alpha = parameters.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : parameters.scale; const components = getMaxComponents(parameters.headSize); const vectorizedHeadSize = parameters.headSize / components; const TILE_SIZE = 12; @@ -391,9 +474,11 @@ const createAttentionProbsProgramInfo = ( { type: DataType.uint32, data: vectorizedHeadSize }, { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: parameters.numHeads }, + { type: DataType.uint32, data: parameters.headSize }, { type: DataType.float, data: alpha }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: parameters.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0; @@ -404,6 +489,12 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }]; if (presentKey) { outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default }); @@ -419,6 +510,16 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims)); } + const seqLensInputVariable = seqLens ? 
inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputVariable) { + inputVars.push(seqLensInputVariable); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInputVariable) { + inputVars.push(totalSequenceLengthInputVariable); + } const output = outputVariable('output', q.dataType, probsShape); const outputVars = [output]; if (presentKey) { @@ -431,9 +532,11 @@ const createAttentionProbsProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'alpha', type: 'f32' as UniformDataElementType }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; @@ -443,21 +546,20 @@ const createAttentionProbsProgramInfo = ( ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} // x holds the N and y holds the M - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; + let batchIdx = workgroup_id.z / uniforms.num_heads; let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; - let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; - ${(() => { - if (feedPastKey && presentKey) { - return ` - let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx; - let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`; - } else { - return ` - let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`; - } - })()} - ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.N; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; + let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + ${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''}; + let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K; + ${presentKey ? 
'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''} var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { @@ -468,31 +570,37 @@ const createAttentionProbsProgramInfo = ( ${(() => { if (feedPastKey && presentKey) { return ` - if (n + local_id.y < uniforms.past_sequence_length) { + if (n + local_id.y < past_sequence_length) { tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; - } else { - tileK[idx] = - key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x]; + } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x]; }`; } else { - return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];'; + return ` + if (n + local_id.y < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; + }`; } })()} ${ - presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : '' + presentKey + ? `if (n + local_id.y < present_sequence_length) { + present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx]; + }` + : '' } } workgroupBarrier(); for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); + value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); } workgroupBarrier(); } - let headOffset = headIdx * uniforms.M * uniforms.N; - if (global_id.y < uniforms.M && global_id.x < uniforms.N) { + if (global_id.y < uniforms.M && global_id.x < total_sequence_length) { + let headOffset = workgroup_id.z * uniforms.M * uniforms.N; let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x; var sum: f32 = ${(() => { switch (components) { @@ -530,13 +638,16 @@ const createVxAttentionScoreProgramInfo = ( pastValue: TensorView | undefined, params: AttentionParameters, pastSequenceLength: number, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? params.nReps : 1; const repeatedVHiddenSize = params.vHiddenSize * nReps; - const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue; + const presentValue = outputCount > 1 && pastValue; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; const presentValueShape = presentValue - ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] + ? 
[params.batchSize, kvNumHeads, totalSequenceLength, params.headSize] : undefined; const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize]; const TILE_SIZE = 12; @@ -551,9 +662,11 @@ const createVxAttentionScoreProgramInfo = ( { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: params.vHeadSize }, { type: DataType.uint32, data: params.numHeads }, + { type: DataType.uint32, data: params.headSize }, { type: DataType.uint32, data: repeatedVHiddenSize }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: params.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0; @@ -561,6 +674,12 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }]; if (presentValue) { outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default }); @@ -572,6 +691,16 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); } + const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLens) { + inputVars.push(seqLensInputVariable!); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInput) { + inputVars.push(totalSequenceLengthInputVariable!); + } const output = outputVariable('output', probs.dataType, outputShape); const outputVars = [output]; if (presentValue) { @@ -582,34 +711,32 @@ const createVxAttentionScoreProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'v_hidden_size', type: 'u32' }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; - var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + var tileV: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let batchIdx = workgroup_id.z / uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 
'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; let m = global_id.y; let n = global_id.x; - - let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; - ${(() => { - if (feedPastValue && presentValue) { - return ` - let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n; - let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n; - `; - } else { - return ` - let offsetB = headIdx * uniforms.N * uniforms.K + n; - `; - } - })()} - ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.K; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch + ${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''}; + let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n; + ${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''} var value = ${probsHelper.type.storage}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (m < uniforms.M && w + local_id.x < uniforms.K) { @@ -620,33 +747,39 @@ const createVxAttentionScoreProgramInfo = ( ${(() => { if (feedPastValue && presentValue) { return ` - if (w + local_id.y < uniforms.past_sequence_length) { - tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; - } else { - tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N]; + if (w + local_id.y < past_sequence_length) { + tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; + } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N]; } `; } else { return ` - tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N]; - `; + if (w + local_id.y < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N]; + }`; } })()} - ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''} + ${ + presentValue + ? 
` + if (w + local_id.y < present_sequence_length) { + present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx]; + }` + : '' + } } workgroupBarrier(); - for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x]; + for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) { + value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x]; } workgroupBarrier(); } // we need to transpose output from BNSH_v to BSND_v - let batchIdx = workgroup_id.z / uniforms.num_heads; - let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads; if (m < uniforms.M && n < uniforms.N) { let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size - + currentBatchHeadNumber * uniforms.N + n; + + headIdx * uniforms.N + n; output[outputIdx] = value; } }`; @@ -671,23 +804,29 @@ export const applyAttention = ( pastValue: TensorView | undefined, attentionBiasInput: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { - // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. + // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0)); - const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; + const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const attentionBias = attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined; const inputsK = [q, k]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { + if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { inputsK.push(pastKey); } if (attentionBias) { inputsK.push(attentionBias); } - + if (seqLens) { + inputsK.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsK.push(totalSequenceLengthInput); + } // Run AttentionProbs const probs = context.compute( createAttentionProbsProgramInfo( @@ -697,31 +836,55 @@ export const applyAttention = ( pastKey, attentionBias, parameters, - attributes, pastSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] }, + { inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] }, )[0]; // Run Softmax context.compute( createInPlaceSoftmaxProgramInfo( probs, - parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.batchSize, + parameters.numHeads, + pastSequenceLength, + parameters.sequenceLength, totalSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: [probs], outputs: [] }, + { inputs: seqLens && totalSequenceLengthInput ? 
[probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] }, ); - // Run AttrionScore + // Run AttentionScore const inputsV = [probs, v]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { + if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { inputsV.push(pastValue); } - context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), { - inputs: inputsV, - outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0], - }); + if (seqLens) { + inputsV.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsV.push(totalSequenceLengthInput); + } + context.compute( + createVxAttentionScoreProgramInfo( + outputCount, + probs, + v, + pastValue, + parameters, + pastSequenceLength, + seqLens, + totalSequenceLengthInput, + ), + { + inputs: inputsV, + outputs: outputCount > 1 ? [0, 2] : [0], + }, + ); }; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { @@ -857,6 +1020,5 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs): undefined, context.inputs[5], params, - attributes, ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 53c2ca2fa47d6..c695a71568c97 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -143,9 +143,11 @@ const createBinaryOpProgramInfo = ( additionalImplementation?: string, outputDataType: number = a.dataType, ): ProgramInfo => { - const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); - let outputShape = a.dims; - let outputSize = ShapeUtil.size(a.dims); + const aDims = a.dims.map((x) => Number(x) ?? 1); + const bDims = b.dims.map((x) => Number(x) ?? 1); + const isBroadcast = !ShapeUtil.areEqual(aDims, bDims); + let outputShape = aDims; + let outputSize = ShapeUtil.size(aDims); let vectorize = false; let sharedDimensionDivisibleBy4 = false; @@ -153,16 +155,16 @@ const createBinaryOpProgramInfo = ( // TODO: deal with zero-sized tensors (eg. dims=[1,0]) const cacheKeyAux = [isBroadcast]; if (isBroadcast) { - const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); + const calculatedShape = BroadcastUtil.calcShape(aDims, bDims, false); if (!calculatedShape) { throw new Error("Can't perform binary op on the given tensors"); } - outputShape = calculatedShape; + outputShape = calculatedShape.slice(); outputSize = ShapeUtil.size(outputShape); - const isAOneElement = ShapeUtil.size(a.dims) === 1; - const isBOneElement = ShapeUtil.size(b.dims) === 1; - const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; - const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; + const isAOneElement = ShapeUtil.size(aDims) === 1; + const isBOneElement = ShapeUtil.size(bDims) === 1; + const aLastDimDivisibleBy4 = aDims.length > 0 && aDims[aDims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = bDims.length > 0 && bDims[bDims.length - 1] % 4 === 0; cacheKeyAux.push(isAOneElement); cacheKeyAux.push(isBOneElement); cacheKeyAux.push(aLastDimDivisibleBy4); @@ -170,8 +172,8 @@ const createBinaryOpProgramInfo = ( // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { - const dimA = a.dims[a.dims.length - i] ?? 1; - const dimB = b.dims[b.dims.length - i] ?? 
1; + const dimA = aDims[aDims.length - i]; + const dimB = bDims[bDims.length - i]; if (dimA === dimB) { sharedDimension *= dimA; } else { @@ -199,8 +201,8 @@ const createBinaryOpProgramInfo = ( getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, - a.dims, - b.dims, + aDims, + bDims, outputShape, vectorize, isBroadcast, @@ -216,7 +218,7 @@ const createBinaryOpProgramInfo = ( dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */) }, programUniforms: [ { type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4) }, - ...createTensorShapeVariables(a.dims, b.dims, outputShape), + ...createTensorShapeVariables(aDims, bDims, outputShape), ], }), }; @@ -280,9 +282,7 @@ export const pow = (context: ComputeContext): void => { } else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) { return ${type}(pow(f32(a), f32(b))); // NaN } - return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${ - roundStr - }(pow(f32(abs(a)), f32(b)))); + return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${roundStr}(pow(f32(abs(a)), f32(b)))); } fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> { // TODO: implement vectorized pow diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c40229cde9e2b..0b9173403cd7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -195,7 +195,7 @@ export interface IndicesHelper { /** * whether the helper is for an input, an output or an internal variable. */ - readonly usage: 'input' | 'output' | 'internal'; + readonly usage: 'input' | 'output' | 'atomicOutput' | 'internal'; /** * the rank of the input or output. @@ -219,7 +219,7 @@ const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [s } // return type is [ storage type, runtime type ] or a single string for both - switch (type) { + switch (Number(type)) { case DataType.float16: return components > 1 ? `vec${components}` : 'f16'; case DataType.float: @@ -733,6 +733,20 @@ export const outputVariable = ( components: 1 | 2 | 3 | 4 = 1, ): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'output', components); +/** + * Create a IndicesHelper for an atomic output. + * + * @param name - the name of the output. + * @param type - the tensor type of the output. + * @param shapeOrRank - the tensor shape or the rank of the output. + * @returns an IndicesHelper for the output. + */ +export const atomicOutputVariable = ( + name: string, + type: number, + shapeOrRank: number | readonly number[], +): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'atomicOutput', 1); + /** * Create a IndicesHelper for an internal variable. * @@ -868,6 +882,7 @@ class ShaderHelperImpl implements ShaderHelper { const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, @builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_index) local_idx : u32, @builtin(local_invocation_id) local_id : vec3` : `@builtin(global_invocation_id) global_id : vec3, @builtin(local_invocation_id) local_id : vec3, @@ -876,7 +891,6 @@ class ShaderHelperImpl implements ShaderHelper { @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? 
`let global_idx = global_id.x; - let local_idx = local_id.x; let workgroup_index = workgroup_id.x;` : `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x; @@ -905,9 +919,8 @@ class ShaderHelperImpl implements ShaderHelper { } this.variables.push(variable); this.appendVariableUniforms(variable); - const access = variable.usage === 'input' ? 'read' : 'read_write'; - const storageType = variable.type.storage; + const storageType = variable.usage === 'atomicOutput' ? `atomic` : variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } @@ -996,27 +1009,3 @@ class ShaderHelperImpl implements ShaderHelper { export const createShaderHelper = (dispatchGroup: [number, number, number], limits: GPUSupportedLimits) => new ShaderHelperImpl(dispatchGroup, limits); - -/** - * This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40 - * Returns the dimensions in the input shape that are broadcasted to - * produce the provided output shape. - * - * The returned dimensions are 0-indexed and sorted. An example: - * inShape = [4, 1, 3] - * outShape = [5, 4, 3, 3] - * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3. - */ -export const getBroadcastDims = (inShape: readonly number[], outShape: readonly number[]): number[] => { - const inRank = inShape.length; - const dims: number[] = []; - for (let i = 0; i < inRank; i++) { - const dim = inRank - 1 - i; - const a = inShape[dim] || 1; - const b = outShape[outShape.length - 1 - i] || 1; - if (b > 1 && a === 1) { - dims.unshift(dim); - } - } - return dims; -}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 236f1b09a6c93..3e168ddedac86 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -4,7 +4,6 @@ import { TensorView } from '../../tensor-view'; import { ComputeContext } from '../types'; -import { createConv2DTransposeMatMulProgramInfo } from './3rd-party/conv_backprop_mm_webgpu'; import { createConvTranspose2DProgramInfo } from './3rd-party/conv_backprop_webgpu'; import { ConvAttributes } from './conv'; import { parseInternalActivationAttributes } from './fuse-utils'; @@ -227,41 +226,16 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose } }; -// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] -const weightTransposePerm = [2, 3, 1, 0]; - const convTranspose2d = ( context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): void => { - const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); - const isChannelsLast = attributes.format === 'NHWC'; - const outputShape = adjustedAttributes.outputShape; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's - // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit - // utilization rate is very low. 
- if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { - context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); - return; - } - const outHeight = outputShape[isChannelsLast ? 1 : 2]; - const outWidth = outputShape[isChannelsLast ? 2 : 3]; - const weightHeight = inputs[1].dims[2]; - const weightWidth = inputs[1].dims[3]; - - const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; - const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; - const dimInner = weightHeight * weightWidth * inputChannels; - - const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; - // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute(createTransposeProgramInfo(inputs[1], weightTransposePerm), { + context.compute(createTransposeProgramInfo(inputs[1], [2, 3, 0, 1]), { inputs: [1], outputs: [attributes.wIsConst ? -2 : -1], })[0]; @@ -271,29 +245,12 @@ const convTranspose2d = ( // STEP.2: prepare reshaped inputs const convTransposeInputs = [inputs[0], transposedWeight]; - const hasBias = inputs.length === 3; - if (hasBias) { - if (!isChannelsLast && inputs[2].dims.length === 1) { - convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); - } else { - convTransposeInputs.push(inputs[2]); - } + if (inputs.length === 3) { + convTransposeInputs.push(inputs[2]); } - - // STEP.3: compute matmul - context.compute( - createConv2DTransposeMatMulProgramInfo( - convTransposeInputs, - adjustedAttributes, - outputShape, - dimAOuter, - dimBOuter, - dimInner, - hasBias, - sequentialAccessByThreads, - ), - { inputs: convTransposeInputs }, - ); + context.compute(createConvTranspose2DProgramInfo(convTransposeInputs, attributes, squeezeOutputShapeFunction), { + inputs: convTransposeInputs, + }); }; const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { @@ -338,12 +295,9 @@ const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttri { ...attributes, pads, strides, dilations, kernelShape }, inputs, ); - context.compute( - createConvTranspose2DProgramInfo(inputs, adjustedAttributes, (outputShape) => - isChannelLast - ? [outputShape[0], outputShape[2], outputShape[3]] - : [outputShape[0], outputShape[1], outputShape[3]], - ), + + convTranspose2d(context, inputs, adjustedAttributes, (outputShape) => + isChannelLast ? 
[outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]], ); }; @@ -352,6 +306,7 @@ export const convTranspose = (context: ComputeContext, attributes: ConvTranspose if (context.inputs[0].dims.length === 3) { convTranspose1d(context, attributes); } else { - convTranspose2d(context, context.inputs, attributes); + const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, context.inputs); + convTranspose2d(context, context.inputs, adjustedAttributes); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index de9f7bc8885ab..f9225baf66eea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -11,7 +11,7 @@ import { computeConv3DInfo, createConv3DNaiveProgramInfo } from './3rd-party/con import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu'; import { createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo } from './conv-grouped'; import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils'; -import { createNaiveMatmulProgramInfo } from './matmul'; +import { createNaiveMatmulProgramInfo } from './matmul-shaders'; import { createTransposeProgramInfo } from './transpose'; export const calculateOutputShape = ( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 4e2bfa9d89924..3691b5ecb602b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -48,11 +48,18 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const shape = Array.from(inputs[1].getBigInt64Array(), Number); const outputShape: number[] = calculateOutputShape(inputShape, shape); const dataType = inputs[0].dataType; - const components = dataType === DataType.bool ? 4 : 1; + const isBoolOrScalar = dataType === DataType.bool || ShapeUtil.size(inputShape) === 1; + const iComponents = + dataType === DataType.bool ? 4 : inputShape.length > 0 && inputShape[inputShape.length - 1] % 4 === 0 ? 4 : 1; + const components = isBoolOrScalar + ? 4 + : outputShape.length > 0 && outputShape[outputShape.length - 1] % 4 === 0 + ? 
4 + : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const getShaderSource = (shaderHelper: ShaderHelper) => { - const input = inputVariable('input', dataType, inputShape.length, components); + const input = inputVariable('input', dataType, inputShape.length, iComponents); const output = outputVariable('output', dataType, outputShape.length, components); let assignment: string; if (dataType === DataType.bool) { @@ -74,9 +81,10 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => }`; } else { assignment = ` - let outputIndices = ${output.offsetToIndices('global_idx')}; + let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; let inputOffset = ${input.broadcastedIndicesToOffset('outputIndices', output)}; - ${output.setByOffset('global_idx', input.getByOffset('inputOffset'))} + let data = ${output.type.value}(${input.getByOffset(`inputOffset / ${iComponents}`)}); + ${output.setByOffset('global_idx', 'data')} }`; } return ` @@ -92,7 +100,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ]; return { name: 'Expand', - shaderCache: { hint: `${outputShape.length}`, inputDependencies: ['rank'] }, + shaderCache: { hint: `${outputShape.length};${iComponents}${components}`, inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts new file mode 100644 index 0000000000000..43b51f6e94a66 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramUniform } from '../types'; + +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; + +export interface GatherNDAttributes extends AttributeWithCacheKey { + readonly batchDims: number; +} + +const computeSliceOffsets = ( + context: ComputeContext, + indicesData: TensorView, + sizesFromSliceDimsData: number[], + batchDims: number, + inputDims: readonly number[], + numSlices: number, + numSlicesPerBatch: number, + inputBatchStride: number, + numSliceDims: number, +) => { + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: numSlices }, + { type: DataType.uint32, data: batchDims }, + { type: DataType.uint32, data: inputDims }, + { type: DataType.uint32, data: sizesFromSliceDimsData }, + { type: DataType.uint32, data: numSlicesPerBatch }, + { type: DataType.uint32, data: inputBatchStride }, + { type: DataType.uint32, data: numSliceDims }, + ]; + + const outputShape = [numSlices]; + programUniforms.push(...createTensorShapeVariables(indicesData.dims, outputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const indices = inputVariable('indices_data', indicesData.dataType, indicesData.dims.length); + const output = outputVariable('input_slice_offsets_data', DataType.uint32, 1, 1); + const variables = [indices, output]; + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'batch_dims', type: 'u32' }, + { name: 'input_dims', type: 'u32', length: inputDims.length }, + { 
name: 'sizes_from_slice_dims_data', type: 'u32', length: sizesFromSliceDimsData.length }, + { name: 'num_slices_per_batch', type: 'u32' }, + { name: 'input_batch_stride', type: 'u32' }, + { name: 'num_slice_dims', type: 'u32' }, + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let batch_idx = global_idx / uniforms.num_slices_per_batch; + let base_offset = batch_idx * uniforms.input_batch_stride; + + let slice_indices_base_offset = global_idx * uniforms.num_slice_dims; + var relative_slice_offset = 0; + for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) { + var index = i32(indices_data[dim_idx + slice_indices_base_offset].x); + let input_dim_idx = uniforms.batch_dims + dim_idx; + if (index < 0) { + ${ + inputDims.length === 1 + ? 'index += i32(uniforms.input_dims);' + : 'index += i32(uniforms.input_dims[input_dim_idx]);' + } + } + ${ + sizesFromSliceDimsData.length === 1 + ? 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);' + : 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);' + } + } + + input_slice_offsets_data[global_idx] = base_offset + u32(relative_slice_offset); + }`; + }; + + return context.compute( + { + name: 'computeSliceOffsets', + shaderCache: { hint: `${inputDims.length}_${sizesFromSliceDimsData.length}`, inputDependencies: ['rank'] }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: context.inputs[1].dataType }], + dispatchGroup: { x: Math.ceil(numSlices / 64) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [indicesData], outputs: [-1] }, + )[0]; +}; + +export const gatherND = (context: ComputeContext, attributes: GatherNDAttributes) => { + const inputs = context.inputs; + const inputShape = inputs[0].dims; + const inputType = inputs[0].dataType; + const indicesShape = inputs[1].dims; + const numSliceDims = indicesShape[indicesShape.length - 1]; + const numSlices = ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1); + const sliceSize = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims + numSliceDims); + const numBatches = ShapeUtil.sizeToDimension(inputShape, attributes.batchDims); + const inputBatchStride = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims); + const numSlicesPerBatch = numSlices / numBatches; + const sizesFromSliceDims = new Array(numSliceDims); + let runningProduct = sliceSize; + for (let i = 0; i < numSliceDims; ++i) { + sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct; + runningProduct *= inputShape[attributes.batchDims + numSliceDims - 1 - i]; + } + + const inputSliceOffsets = computeSliceOffsets( + context, + inputs[1], + sizesFromSliceDims, + attributes.batchDims, + inputShape, + numSlices, + numSlicesPerBatch, + inputBatchStride, + numSliceDims, + ); + + const lastIndicesDimension = attributes.batchDims + numSliceDims; + if (lastIndicesDimension > inputShape.length) { + throw new Error('last dimension of indices must not be larger than rank of input tensor'); + } + + const outputShape = indicesShape.slice(0, -1).concat(inputShape.slice(lastIndicesDimension)); + const outputSize = ShapeUtil.size(outputShape); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: sliceSize }, + ...createTensorShapeVariables(inputs[0].dims, inputSliceOffsets.dims, outputShape), + ]; + + 
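The running-product loop above is what maps an indices tuple to a flat element offset inside one batch. A small worked example of that arithmetic, using arbitrary illustrative shapes (the concrete numbers below are assumptions for illustration, not values from the operator's tests):

// Hypothetical case: data shape [2, 3, 4, 5], indices whose last dimension is 2, batch_dims = 0.
const inputShape = [2, 3, 4, 5];
const batchDims = 0;
const numSliceDims = 2;

// Elements remaining after the slice dims: 4 * 5 = 20 per gathered slice.
const sliceSize = inputShape.slice(batchDims + numSliceDims).reduce((a, b) => a * b, 1);

// Same running-product construction as in gatherND above.
const sizesFromSliceDims = new Array<number>(numSliceDims);
let runningProduct = sliceSize;
for (let i = 0; i < numSliceDims; ++i) {
  sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct;
  runningProduct *= inputShape[batchDims + numSliceDims - 1 - i];
}

// sizesFromSliceDims === [60, 20]: an index pair [i, j] selects the slice that starts at
// element offset i * 60 + j * 20 within its batch, matching the shader's relative_slice_offset.
console.log(sliceSize, sizesFromSliceDims); // 20 [ 60, 20 ]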
const getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('data', inputs[0].dataType, inputs[0].dims.length); + const indices = inputVariable('slice_offsets', DataType.uint32, inputSliceOffsets.dims.length); + + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + return ` + ${shaderHelper + .registerUniform('output_size', 'u32') + .registerUniform('slice_size', 'u32') + .declareVariables(input, indices, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let slice_offset = slice_offsets[global_idx / uniforms.slice_size]; + output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size]; + }`; + }; + context.compute( + { + name: 'GatherND', + shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [inputs[0], inputSliceOffsets] }, + ); +}; + +export const parseGatherNDAttributes = (attributes: Record): GatherNDAttributes => { + const batchDims = attributes.batch_dims as number; + return { + batchDims, + cacheKey: '', + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 7f2469d95e1c1..09365f3b984b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -55,9 +55,15 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt if (!outputShape) { throw new Error("Can't use gemm on the given tensors"); } + const tileSize = 16; + const numTileN = Math.ceil(N / tileSize); + const numTileM = Math.ceil(M / tileSize); + // TODO: Find the condition when to use the naive one. + const useShared = true; + const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: useShared ? 
numTileN : outputSize }, { type: DataType.uint32, data: M }, { type: DataType.uint32, data: N }, { type: DataType.uint32, data: K }, @@ -130,6 +136,159 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt }`; }; + const getShaderSourceShared = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims); + const b = inputVariable('b', inputs[1].dataType, inputs[1].dims); + let c: IndicesHelper | null = null; + const variables = [a, b]; + if (inputs.length === 3) { + c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length); + variables.push(c); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + variables.push(output); + const uniforms: UniformsArrayType = [ + { name: 'num_tile_n', type: 'u32' }, + { name: 'M', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'K', type: 'u32' }, + { name: 'alpha', type: 'f32' }, + { name: 'beta', type: 'f32' }, + ]; + + let calcResult = ''; + let fillWorkgroupMemory = ''; + if (attributes.transA && attributes.transB) { + fillWorkgroupMemory = ` + var col = tile_row_start + local_id.x; + var row = k_start + local_id.y; + if (col < uniforms.M && row < uniforms.K) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = k_start + local_id.x; + row = tile_col_start + local_id.y; + if (col < uniforms.K && row < uniforms.N) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[k][local_id.y] * tile_b[local_id.x][k];`; + } else if (attributes.transA && !attributes.transB) { + fillWorkgroupMemory = ` + var col = tile_row_start + local_id.x; + var row = k_start + local_id.y; + if (col < uniforms.M && row < uniforms.K) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = tile_col_start + local_id.x; + row = k_start + local_id.y; + if (col < uniforms.N && row < uniforms.K) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[k][local_id.y] * tile_b[k][local_id.x];`; + } else if (!attributes.transA && attributes.transB) { + fillWorkgroupMemory = ` + var col = k_start + local_id.x; + var row = tile_row_start + local_id.y; + if (col < uniforms.K && row < uniforms.M) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = k_start + local_id.x; + row = tile_col_start + local_id.y; + if (col < uniforms.K && row < uniforms.N) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[local_id.y][k] * tile_b[local_id.x][k];`; + } else if (!attributes.transA && !attributes.transB) { + fillWorkgroupMemory = ` + var col = k_start + local_id.x; + var row = tile_row_start + local_id.y; + if (col < uniforms.K && row < uniforms.M) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = tile_col_start + local_id.x; + row = k_start + local_id.y; + if (col < uniforms.N && row < uniforms.K) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; + } else { + 
tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[local_id.y][k] * tile_b[k][local_id.x];`; + } + + const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;'; + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} + var tile_a: array, ${tileSize}>; + var tile_b: array, ${tileSize}>; + ${shaderHelper.mainStart([tileSize, tileSize, 1])} + let tile_col_start = (workgroup_index % uniforms.num_tile_n) * ${tileSize}; + let tile_row_start = (workgroup_index / uniforms.num_tile_n) * ${tileSize}; + let num_tiles = (uniforms.K - 1) / ${tileSize} + 1; + var k_start = 0u; + var value = ${output.type.value}(0); + for (var t: u32 = 0u; t < num_tiles; t++) { + ${fillWorkgroupMemory} + k_start = k_start + ${tileSize}; + workgroupBarrier(); + + for (var k: u32 = 0u; k < ${tileSize}; k++) { + ${calcResult} + } + workgroupBarrier(); + } + + ${calculateAlpha} + let m = tile_row_start + local_id.y; + let n = tile_col_start + local_id.x; + ${(() => { + if (c != null) { + return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${ + output.type.value + }(uniforms.beta) * ${c.getByOffset('cOffset')};`; + } + return ''; + })()} + if (m < uniforms.M && n < uniforms.N) { + output[m * uniforms.N + n] = value; + } + }`; + }; + + if (useShared) { + return { + name: 'GemmShared', + shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: numTileN * numTileM }, + programUniforms, + }), + getShaderSource: getShaderSourceShared, + }; + } + return { name: 'Gemm', shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts b/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts new file mode 100644 index 0000000000000..50c71472434ad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; + +import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common'; + +let [idxN, idxC, idxH, idxW] = [0, 1, 2, 3]; // NCHW +type Mode = 'bilinear' | 'nearest' | 'bicubic'; +type PaddingMode = 'zeros' | 'border' | 'reflection'; +type Format = 'NHWC' | 'NCHW'; +export interface GridSampeAttributes extends AttributeWithCacheKey { + alignCorners: number; + mode: Mode; + paddingMode: PaddingMode; + format: Format; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 4) { + throw new Error('only 4-D tensor is supported.'); + } + if (inputs[0].dims.length !== inputs[1].dims.length) { + throw new Error('input dimensions must be equal to grid dimensions'); + } + + if (inputs[0].dims.length - 2 !== inputs[1].dims[inputs[1].dims.length - 1]) { + throw new Error(`last dimension of grid must be equal to ${inputs[0].dims.length - 2}`); + } + + if (inputs[0].dims[0] !== inputs[1].dims[0]) { + throw new Error('grid batch size must match input batch size'); + } +}; + +const gsGetCubicCoeffs = ` + fn gs_get_cubic_coeffs(x: f32) -> vec4 { + let cubic_alpha = -0.75f; + let x_abs = abs(x); + var coeffs: vec4; + coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha); + coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1); + coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1); + coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha); + return coeffs; + } +`; + +const gsBicubicInterpolate = (dataType: string): string => ` + fn gs_bicubic_interpolate(p: mat4x4<${dataType}>, x: f32, y: f32) -> ${dataType} { + var v: vec4; + var coeffs = gs_get_cubic_coeffs(x); + for (var i = 0; i < 4; i++) { + v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3]; + } + coeffs = gs_get_cubic_coeffs(y); + let pixel = ${dataType}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]); + return pixel; + } +`; + +const gsDenormalize = (attributes: GridSampeAttributes): string => ` + fn gs_denormalize(n: f32, length: i32) -> f32 { + ${ + attributes.alignCorners === 0 + ? ` + // alignCorners: false => [-1, 1] to [-0.5, length - 0.5] + return ((n + 1.0) * f32(length) - 1.0) / 2.0; + ` + : ` + // alignCorners: true => [-1, 1] to [0, length - 1] + return (n + 1.0) / 2.0 * (f32(length - 1)); + ` + } + } +`; + +const gsReflect = (attributes: GridSampeAttributes): string => ` + ${ + attributes.paddingMode === 'reflection' + ? 
` + fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 { + var dx = 0.0; + var fx = f32(x); + let range = x_max - x_min; + if (fx < x_min) { + dx = x_min - fx; + let n = u32(dx / range); + let r = dx - f32(n) * range; + if (n % 2 == 0) { + fx = x_min + r; + } else { + fx = x_max - r; + } + } else if (fx > x_max) { + dx = fx - x_max; + let n = u32(dx / range); + let r = dx - f32(n) * range; + if (n % 2 == 0) { + fx = x_max - r; + } else { + fx = x_min + r; + } + } + return u32(fx); + }` + : '' + } +`; + +const pixelAtGrid = (input: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string => + ` + fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4) -> ${dataType} { + var pixel = ${dataType}(0); + var indices = vec4(0); + indices[${idxN}] = batch; + indices[${idxC}] = channel;` + + (() => { + switch (attributes.paddingMode) { + case 'zeros': + return ` + if (r >= 0 && r < H && c >=0 && c < W) { + indices[${idxH}] = u32(r); + indices[${idxW}] = u32(c); + } + `; + case 'border': + return ` + indices[${idxH}] = u32(clamp(r, 0, H - 1)); + indices[${idxW}] = u32(clamp(c, 0, W - 1)); + `; + case 'reflection': + return ` + indices[${idxH}] = gs_reflect(r, border[1], border[3]); + indices[${idxW}] = gs_reflect(c, border[0], border[2]); + `; + default: + throw new Error(`padding mode ${attributes.paddingMode} is not supported`); + } + })() + + ` + return ${input.getByIndices('indices')}; + } +`; + +const computePixel = (output: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string => + (() => { + switch (attributes.mode) { + case 'nearest': + return ` + let result = pixel_at_grid(i32(round(y)), i32(round(x)), H_in, W_in, indices[${idxN}], indices[${idxC}], border); + `; + case 'bilinear': + return ` + let x1 = i32(floor(x)); + let y1 = i32(floor(y)); + let x2 = x1 + 1; + let y2 = y1 + 1; + + let p11 = pixel_at_grid(y1, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + let p12 = pixel_at_grid(y1, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + let p21 = pixel_at_grid(y2, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + let p22 = pixel_at_grid(y2, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + + let dx2 = ${dataType}(f32(x2) - x); + let dx1 = ${dataType}(x - f32(x1)); + let dy2 = ${dataType}(f32(y2) - y); + let dy1 = ${dataType}(y - f32(y1)); + let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + `; + case 'bicubic': + return ` + let x0 = i32(floor(x)) - 1; + let y0 = i32(floor(y)) - 1; + var p: mat4x4<${dataType}>; + for (var h = 0; h < 4; h++) { + for (var w = 0; w < 4; w++) { + p[h][w] = pixel_at_grid(h + y0, w + x0, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + } + } + + let dx = x - f32(x0 + 1); + let dy = y - f32(y0 + 1); + let result = gs_bicubic_interpolate(p, dx, dy); + `; + default: + throw new Error(`mode ${attributes.mode} is not supported`); + } + })() + `${output.setByOffset('global_idx', 'result')}`; + +const createGridSampleProgramInfo = (inputs: readonly TensorView[], attributes: GridSampeAttributes): ProgramInfo => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length); + // discard last dimension for using vec2 to access grid data + const gridShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2]]; + const grid = inputVariable('grid', inputs[1].dataType, gridShape.length, 2); + let outputShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[1].dims[1], inputs[1].dims[2]]; + if 
(attributes.format === 'NHWC') { + outputShape = [inputs[0].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[0].dims[3]]; + [idxN, idxC, idxH, idxW] = [0, 3, 1, 2]; + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const dataType = x.type.value; + const outputSize = ShapeUtil.size(outputShape); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputs[0].dims, gridShape, outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(x, grid, output)} + ${gsGetCubicCoeffs} + ${gsBicubicInterpolate(dataType)} + ${gsDenormalize(attributes)} + ${gsReflect(attributes)} + ${pixelAtGrid(x, dataType, attributes)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let H_in = i32(uniforms.x_shape[${idxH}]); + let W_in = i32(uniforms.x_shape[${idxW}]); + + ${ + attributes.alignCorners === 0 + ? ` + let x_min = -0.5; + let x_max = f32(W_in) - 0.5; + let y_min = -0.5; + let y_max = f32(H_in) - 0.5; + ` + : ` + let x_min = 0.0; + let x_max = f32(W_in) - 1.0; + let y_min = 0.0; + let y_max = f32(H_in) - 1.0; + ` + }; + let border = vec4(x_min, y_min, x_max, y_max); + + let indices = ${output.offsetToIndices('global_idx')}; + var grid_indices = vec3(indices[${idxN}], indices[${idxH}], indices[${idxW}]); + let nxy = ${grid.getByIndices('grid_indices')}; + var x = gs_denormalize(f32(nxy[0]), W_in); + var y = gs_denormalize(f32(nxy[1]), H_in); + + ${computePixel(output, dataType, attributes)} + }`; + + return { + name: 'GridSample', + shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies: ['type', 'type'] }, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }; + }, + getShaderSource, + }; +}; + +export const gridSample = (context: ComputeContext, attributes: GridSampeAttributes): void => { + validateInputs(context.inputs); + context.compute(createGridSampleProgramInfo(context.inputs, attributes)); +}; + +export const parseGridSampleAttributes = (attributes: Record): GridSampeAttributes => + createAttributeWithCacheKey({ + alignCorners: attributes.align_corners as number, + mode: attributes.mode as Mode, + paddingMode: attributes.padding_mode as PaddingMode, + format: attributes.format as Format, + }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 56291c037b7da..bbe25460d6fd3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -1,31 +1,49 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
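As a point of reference for the gs_denormalize WGSL helper in the GridSample shader above, the same mapping can be written in plain TypeScript. This is a minimal sketch; the function name and the sample numbers are illustrative and not part of the change:

const denormalize = (n: number, length: number, alignCorners: boolean): number =>
  alignCorners
    ? ((n + 1) / 2) * (length - 1) // align_corners = 1: [-1, 1] maps to [0, length - 1]
    : ((n + 1) * length - 1) / 2;  // align_corners = 0: [-1, 1] maps to [-0.5, length - 0.5]

// For a width of 4: denormalize(-1, 4, true) === 0 and denormalize(1, 4, true) === 3,
// while denormalize(-1, 4, false) === -0.5 and denormalize(1, 4, false) === 3.5.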
-import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; -import { ShapeUtil } from '../../util'; import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; -import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { ComputeContext } from '../types'; -import { - applyAttention, - AttentionAttrs, - AttentionMaskType, - AttentionParameters, - AttentionQkvFormat, -} from './attention'; -import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; +import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention'; import { maybeTransposeToBNSHAndAddBias } from './multihead-attention'; -import { createTileProgramInfo } from './tile'; +import { createSplitProgramInfo, SplitAttributes } from './split'; import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; - -export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { +export interface GroupQueryAttentionAttributes { + numHeads: number; + kvNumHeads: number; + scale: number; + softcap: number; + doRotary: number; + rotaryInterleaved: number; + smoothSoftmax: boolean; + localWindowSize: number; +} + +export const validateInputs = ( + inputs: readonly TensorView[], + attributes: GroupQueryAttentionAttributes, +): AttentionParameters => { + if (attributes.doRotary && inputs.length <= 7) { + throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified'); + } const query = inputs[0]; const key = inputs[1]; const value = inputs[2]; const pastKey = inputs[3]; const pastValue = inputs[4]; - + if (attributes.localWindowSize !== -1) { + throw new Error('Local attention is not supported'); + } + if (attributes.softcap !== 0) { + throw new Error('Softcap is not supported'); + } + if (attributes.rotaryInterleaved !== 0) { + throw new Error('Rotary interleaved is not supported'); + } + if (attributes.smoothSoftmax) { + throw new Error('Smooth softmax is not supported'); + } // Abbreviation and Meanings: // B: batch_size // S: sequence_length (input sequence length of query) @@ -62,17 +80,32 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const dmmhaPacking = false; const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = + let hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4]; let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; - let maxSequenceLength = 0; - const headSize = Math.floor(hiddenSize / attributes.numHeads); + const packedQKV = !key || key.dims.length === 0; + const headSize = !packedQKV + ? Math.floor(hiddenSize / attributes.numHeads) + : Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads)); + if (packedQKV) { + hiddenSize = headSize * attributes.numHeads; + } const hasPastKey = pastKey && pastKey.dims.length !== 0; const hasPastValue = pastValue && pastValue.dims.length !== 0; - // TODO : this should be from attributes. - const isPastkvBSNH = true; + // Currenly the onnxruntime GQA specification only support key/value BNSH format. 
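+  // Heuristically detect a BSNH-shaped past_key (batch, past_sequence_length, kv_num_heads, head_size) so it can be rejected below; only the BNSH layout is handled here.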
+ const isPastkvBSNH = + hasPastKey && + pastKey.dims.length === 4 && + pastKey.dims[0] === batchSize && + pastKey.dims[1] !== attributes.kvNumHeads && + pastKey.dims[2] === attributes.kvNumHeads && + pastKey.dims[3] === headSize; + + if (isPastkvBSNH) { + throw new Error('BSNH pastKey/pastValue is not supported'); + } if (hasPastKey && hasPastValue) { if (pastKey.dims.length !== 4) { throw new Error('Input "past_key" is expected to have 4 dimensions'); @@ -80,21 +113,13 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (pastValue.dims.length !== 4) { throw new Error('Input "past_value" is expected to have 4 dimensions'); } - if (isPastkvBSNH) { - // For BSNH - pastSequenceLength = pastKey.dims[1]; - maxSequenceLength = pastKey.dims[1]; - } else { - // For BNSH - pastSequenceLength = pastKey.dims[2]; - maxSequenceLength = pastKey.dims[2]; - } + pastSequenceLength = pastKey.dims[2]; } else if (hasPastKey || hasPastValue) { throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); } - let qkvFormat: AttentionQkvFormat; - if (key) { + let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH; + if (key && key.dims.length > 0) { if (query.dims.length !== 3) { throw new Error('Input "query" is expected to have 3 dimensions when key is given'); } @@ -109,7 +134,6 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (query.dims[2] % key.dims[2] !== 0) { throw new Error('Dimension 2 of "query" should be a multiple of "key"'); } - qkvFormat = AttentionQkvFormat.qkvBSNH; kvSequenceLength = key.dims[1]; } else if (key.dims.length === 5) { if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { @@ -118,15 +142,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (value) { throw new Error('Expect "value" be none when "key" has packed kv format.'); } - qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; kvSequenceLength = key.dims[1]; } else { // key_dims.size() == 4 (cross-attention with past_key) if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } - - qkvFormat = AttentionQkvFormat.unknown; kvSequenceLength = key.dims[2]; } } else { @@ -143,8 +164,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const maskType: AttentionMaskType = AttentionMaskType.none; let passPastInKv = false; - let vHiddenSize = hiddenSize; - if (value) { + let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize; + if (value && value.dims.length > 0) { if (value.dims.length !== 3 && value.dims.length !== 4) { throw new Error('Input "value" is expected to have 3 or 4 dimensions'); } @@ -166,7 +187,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent passPastInKv = true; } } - const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const seqlLens = inputs.length > 4 ? 
inputs[5] : undefined; + if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) { + throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size'); + } + const totalSequenceLength = -1; + const maxSequenceLength = -1; const broadcastResPosBias = false; return { @@ -180,181 +206,36 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent hiddenSize, vHiddenSize, headSize, - vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!), + vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads), numHeads: attributes.numHeads, kvNumHeads: attributes.kvNumHeads, - nReps: attributes.numHeads / attributes.kvNumHeads!, + nReps: attributes.numHeads / attributes.kvNumHeads, pastPresentShareBuffer: false, maskType, scale: attributes.scale, broadcastResPosBias, passPastInKv, qkvFormat, - isPastkvBSNH, }; }; -const createConcatProgramInfo = ( - a: TensorView, - b: TensorView | undefined, - dataType: DataType, - params: AttentionParameters, -): ProgramInfo => { - const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize]; - const component = 4; - const outputSize = ShapeUtil.size(outputShape) / component; - const presentSequenceLength = params.totalSequenceLength; - const output = outputVariable('present_kv', dataType, outputShape.length, component); - const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component); - const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined; - - const H = Math.ceil(params.headSize / component); - const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 }; - - const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank']; - - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, - { type: DataType.uint32, data: params.pastSequenceLength }, - { type: DataType.uint32, data: params.kvSequenceLength }, - { type: DataType.uint32, data: params.totalSequenceLength }, - ]; - - const inputs = [inputA]; - if (inputB) { - programUniforms.push( - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b!.dims), - ...createTensorShapeVariables(outputShape), - ); - inputs.push(inputB); - } else { - programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape)); - } - const uniforms: UniformsArrayType = [ - { name: 'output_size', type: 'u32' }, - { name: 'past_seqlen', type: 'u32' }, - { name: 'new_seqlen', type: 'u32' }, - { name: 'present_seqlen', type: 'u32' }, - ]; - - const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H; - var past_head_stride = uniforms.past_seqlen * H; - if (is_bsnh) { - past_head_stride = H; - } - let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; - present_kv[out_offset] = past_kv[in_offset];`; - const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H; - let new_row_stride = num_heads * H; - let new_head_stride = H; - let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; - present_kv[out_offset] = new_kv[in_offset];`; - const concatStr = b - ? `if (s < past_seqlen) { - ${pastStr} - } else if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }` - : `if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }`; - - // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit. 
- const getShaderSource = (shaderHelper: ShaderHelper) => ` - - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)} - ${shaderHelper.mainStart([H, params.kvNumHeads!, 1])} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var indices = ${output.offsetToIndices('global_idx')}; - let h = local_id.x; - let n = local_id.y; - let s = workgroup_id.x; - let b = workgroup_id.y; - let num_heads = ${params.kvNumHeads!}u; - let H = ${H}u; - - let present_seqlen = uniforms.present_seqlen; - let present_batch_stride = present_seqlen * num_heads * H; - var row_stride = H; - let is_bsnh = ${params.isPastkvBSNH}; - - if (is_bsnh) { - row_stride = num_heads * H; - } - var present_head_stride = present_seqlen * H; - if (is_bsnh) { - present_head_stride = H; - } - - let past_seqlen = uniforms.past_seqlen; - - let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; - ${concatStr} - }`; - - return { - name: 'ConcatPastNew', - shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies }, - getRunData: () => ({ - outputs: [{ dims: outputShape, dataType }], - dispatchGroup: dispatch, - programUniforms, - }), - getShaderSource, - }; -}; - -export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({ ...attributes }); - const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] }); -const maybeExpandAndTransposeToBNSH = ( - context: ComputeContext, - input: TensorView, - pastKV: TensorView | undefined, - params: AttentionParameters, - outputIndex: number, -) => { +const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: AttentionParameters) => { let reshapedInput = input; const numHeads = params.kvNumHeads!; - const nReps = params.nReps!; if (input.dims.length === 3 && params.kvSequenceLength !== 0) { reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]); - } - - if (pastKV) { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), { - inputs: [reshapedInput, pastKV], - outputs: [params.isPastkvBSNH ? outputIndex : -1], - })[0]; - } else { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), { - inputs: [reshapedInput], - outputs: [params.isPastkvBSNH ? 
outputIndex : -1], - })[0]; - } - if (nReps !== 1) { - reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), { + reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], })[0]; - reshapedInput = reshapedInput.reshape([ - params.batchSize, - params.totalSequenceLength, - numHeads * nReps, - params.headSize, - ]); } - return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { - inputs: [reshapedInput], - outputs: [-1], - })[0]; + return reshapedInput; }; -export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { +export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { const params = validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 5) { throw new Error('Packed QKV is not implemented'); @@ -364,19 +245,49 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti throw new Error('Packed KV is not implemented'); } + const q = context.inputs[0]; + const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined; + const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined; + const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; + const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined; + const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined; + const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; + + // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead. + + const splitAttributes: SplitAttributes = createAttributeWithCacheKey({ + axis: 2, + numOutputs: 3, + splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize], + }); + const [query, key, value] = + !k && !v + ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] }) + : [q, k!, v!]; + const Q = maybeTransposeToBNSHAndAddBias( context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, - context.inputs[0], + query, undefined, 0, ); - const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; - const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? 
context.inputs[4] : undefined; - const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1); - const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2); - applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes); + applyAttention( + context, + Q, + maybeTransposeToBNSH(context, key, params), + maybeTransposeToBNSH(context, value, params), + undefined, + undefined, + pastKey, + pastValue, + undefined, + params, + seqLens, + totalSequenceLengthInput, + ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 859bd850862aa..a357d29667319 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -36,7 +36,10 @@ const computeChannelScaleShift = ( const f32Type = components === 1 ? 'f32' : `vec${components}f`; const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`; const unitsOfWork = n * c; - + let workgroupSize = 64; + if (unitsOfWork === 1) { + workgroupSize = 256; + } const inputShape = [n, c, h / components]; const outputShape = [n, c, 2]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; @@ -49,7 +52,6 @@ const computeChannelScaleShift = ( const b = inputVariable('bias', bias.dataType, bias.dims); const output = outputVariable('output', DataType.float, 3, 2); const variables = [x, s, b, output]; - const workgroupSize = 64; return ` var workgroup_shared : array<${wgType}, ${workgroupSize}>; const workgroup_size = ${workgroupSize}u; @@ -91,7 +93,7 @@ const computeChannelScaleShift = ( { name: 'InstanceNormComputeChannelScaleShift', // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, + shaderCache: { hint: `${components};${epsilon};${workgroupSize}`, inputDependencies }, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: DataType.float }], dispatchGroup: { x: unitsOfWork }, @@ -187,14 +189,21 @@ const createInstanceNormNHWCProgramInfo = ( const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // 1. transpose x from NHWC to NCHW + let needTranspose = false; const transposedXPerm = [0, xShape.length - 1]; for (let i = 0; i < xShape.length - 2; i++) { + needTranspose = needTranspose || xShape[i + 1] !== 1; transposedXPerm.push(i + 1); } - const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { - inputs: [context.inputs[0]], - outputs: [-1], - })[0]; + + needTranspose = needTranspose && xShape[xShape.length - 1] !== 1; + + const transposedX = needTranspose + ? context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { + inputs: [context.inputs[0]], + outputs: [-1], + })[0] + : context.inputs[0].reshape(Array.from({ length: xShape.length }, (_, i) => xShape[transposedXPerm[i]])); // 2. compute channel scale and channel shift. const channelScaleShift = computeChannelScaleShift( context, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts new file mode 100644 index 0000000000000..e1f73f137e43e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
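To make the packed-QKV split in groupQueryAttention above concrete, here is the arithmetic with illustrative head counts (the numbers are examples only, not values taken from this change):

// Assumed example sizes: num_heads = 32, kv_num_heads = 8, head_size = 128.
const numHeads = 32, kvNumHeads = 8, headSize = 128;
// A packed QKV input has hidden size (num_heads + 2 * kv_num_heads) * head_size = 6144,
// and the Split along axis 2 slices it into Q/K/V widths:
const splitSizes = [numHeads * headSize, kvNumHeads * headSize, kvNumHeads * headSize];
// -> [4096, 1024, 1024], which sum back to the packed hidden size 6144.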
+ +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ProgramInfo, ProgramUniform } from '../types'; + +import { + createTensorShapeVariables, + getElementAt, + getMaxComponents, + IndicesHelper, + inputVariable, + internalVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; +import { + appendActivationUniforms, + appendActivationUniformsData, + getActivationSnippet, + InternalActivationAttributes, +} from './fuse-utils'; + +// Helper that convert output batch indices to input batch indices using only the rank and +// the shape information in uniform +export const convertOutputBatchIndicesToInputBatchIndices = ( + targetIndicesName: string, + inputVariable: IndicesHelper, + inputBatchRank: number, + outputBatchRank: number, + batchIndicesName: string, +) => { + // Assume outputBatchRank >= inputBatchRank, the first outputBatchRank - inputBatchRank of + // outputBatchRank should be ignored. + const extendingInputRank = outputBatchRank - inputBatchRank; + return ` + ${Array.from({ length: inputBatchRank }) + .map( + (_, i) => ` + if (${getElementAt(inputVariable.shape, i, inputVariable.rank)} != 1) { + ${inputVariable.indicesSet(targetIndicesName, i, getElementAt(batchIndicesName, i + extendingInputRank, outputBatchRank))} + } else { + ${inputVariable.indicesSet(targetIndicesName, i, 0)} + }`, + ) + .join('')} +`; +}; + +export const createNaiveMatmulProgramInfo = ( + inputs: readonly TensorView[], + activationAttributes: InternalActivationAttributes, + outputShape: readonly number[], + reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], +): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + + const M = aShape[aShape.length - 2]; + const N = bShape[bShape.length - 1]; + const K = aShape[aShape.length - 1]; + const components = getMaxComponents(N); + const aComponents = getMaxComponents(K); + const outputNumber = getMaxComponents(M); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const hasBias = inputs.length > 2; + const outerDims = reshapedOutputShape ? 
reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchSize = ShapeUtil.size(outerDims); + const outputShapeInShader = [batchSize, M, N]; + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: M }, + { type: DataType.uint32, data: N }, + { type: DataType.uint32, data: K }, + ]; + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } + programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', inputs[1].dataType, bShape.length, components); + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); + const inputVariables = [a, b]; + let processBias = ''; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + processBias = `${ + isChannelsLast ? `value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);` + }`; + } + + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'M', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'K', type: 'u32' }, + ]; + appendActivationUniforms(activationAttributes, uniforms); + + const calcResult = (): string => { + let calcStr = `var a_data: ${a.type.value};`; + for (let i = 0; i < aComponents; i++) { + calcStr += ` + let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`; + } + for (let i = 0; i < outputNumber; i++) { + calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; + + for (let j = 0; j < aComponents; j++) { + calcStr += ` + values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`; + } + } + return calcStr; + }; + + return ` + ${shaderHelper + .registerUniforms(uniforms) + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let col = (global_idx % (uniforms.N / ${components})) * ${components}; + var index1 = global_idx / (uniforms.N / ${components}); + let stride1 = uniforms.M / ${outputNumber}; + let row = (index1 % stride1) * ${outputNumber}; + let batch = index1 / stride1; + + ${outputShape.length === 2 ? 
'' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`} + + var a_indices: ${a.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices('a_indices', a, a.rank - 2, batchDims.rank, 'batch_indices')} + ${a.indicesSet('a_indices', a.rank - 2, 0)} + ${a.indicesSet('a_indices', a.rank - 1, 0)} + let a_offset = ${a.indicesToOffset('a_indices')}; + + var b_indices: ${b.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices('b_indices', b, b.rank - 2, batchDims.rank, 'batch_indices')} + ${b.indicesSet('b_indices', b.rank - 2, 0)} + ${b.indicesSet('b_indices', b.rank - 1, 0)} + let b_offset = ${b.indicesToOffset('b_indices')}; + var values: array<${output.type.value}, ${outputNumber}>; + for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) { + ${calcResult()} + } + for (var i = 0u; i < ${outputNumber}u; i++) { + var value = values[i]; + ${processBias} + ${applyActivation} + let cur_indices = ${output.type.indices}(batch, row + i, col); + let offset = ${output.indicesToOffset('cur_indices')}; + ${output.setByOffset(`offset / ${components}`, 'value')}; + } + } + `; + }; + return { + name: 'MatMulNaive', + shaderCache: { + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, + inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'], + }, + getRunData: () => ({ + outputs: [ + { + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 7605e67c972b9..46a358aacdad4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,184 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
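The convertOutputBatchIndicesToInputBatchIndices helper above emits WGSL, but the rule it implements is ordinary broadcasting: output batch indices are right-aligned against the input's batch dims, and any input dim of size 1 is pinned to index 0. A minimal TypeScript sketch of that rule (the helper name and shapes below are illustrative only):

const toInputBatchIndices = (outputBatchIndices: number[], inputBatchShape: number[]): number[] => {
  // Right-align the (shorter) input batch shape against the output batch indices.
  const extend = outputBatchIndices.length - inputBatchShape.length;
  return inputBatchShape.map((dim, i) => (dim !== 1 ? outputBatchIndices[i + extend] : 0));
};

// e.g. an input with batch shape [3, 1] broadcast into an output batch of rank 3:
// toInputBatchIndices([1, 2, 3], [3, 1]) === [2, 0]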
-import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { BroadcastUtil, ShapeUtil } from '../../util'; -import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; +import { ComputeContext } from '../types'; +import { createNaiveMatmulProgramInfo } from './matmul-shaders'; import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu'; -import { - createTensorShapeVariables, - getBroadcastDims, - getMaxComponents, - IndicesHelper, - inputVariable, - internalVariable, - outputVariable, - ShaderHelper, - tensorTypeToWsglStorageType, - UniformsArrayType, -} from './common'; -import { - appendActivationUniforms, - appendActivationUniformsData, - getActivationSnippet, - InternalActivationAttributes, -} from './fuse-utils'; - -export const createNaiveMatmulProgramInfo = ( - inputs: readonly TensorView[], - activationAttributes: InternalActivationAttributes, - outputShape: readonly number[], - reshapedOutputShape?: readonly number[], - isChannelsLast = false /* only used for conv2dByMatMul*/, - squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], -): ProgramInfo => { - const aShape = inputs[0].dims; - const bShape = inputs[1].dims; - - const M = aShape[aShape.length - 2]; - const N = bShape[bShape.length - 1]; - const K = aShape[aShape.length - 1]; - const components = getMaxComponents(N); - const aComponents = getMaxComponents(K); - const outputNumber = getMaxComponents(M); - const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; - const hasBias = inputs.length > 2; - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchSize = ShapeUtil.size(outerDims); - const outputShapeInShader = [batchSize, M, N]; - - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, - { type: DataType.uint32, data: M }, - { type: DataType.uint32, data: N }, - { type: DataType.uint32, data: K }, - ]; - appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - } - programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); - const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); - const b = inputVariable('b', inputs[1].dataType, bShape.length, components); - const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const baseType = tensorTypeToWsglStorageType(output.type.tensor); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); - const inputVariables = [a, b]; - let processBias = ''; - if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); - processBias = `${ - isChannelsLast ? 
`value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);` - }`; - } - - const outerDimsA = aShape.slice(0, -2); - const outerDimsB = bShape.slice(0, -2); - const broadCastADims = getBroadcastDims(outerDimsA, outerDims); - const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); - const uniforms: UniformsArrayType = [ - { name: 'output_size', type: 'u32' }, - { name: 'M', type: 'u32' }, - { name: 'N', type: 'u32' }, - { name: 'K', type: 'u32' }, - ]; - appendActivationUniforms(activationAttributes, uniforms); - - const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { - const rank = variable.rank; - const name = variable.name; - if (rank === 2) { - return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`; - } - const batchRank = batchDims.rank; - let resStr = `var ${name}_indices: ${variable.type.indices};`; - for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`; - } - broadCastDims.forEach((i) => { - resStr += `\n${name}_indices[${i}] = 0;`; - }); - resStr += `${name}_indices[${rank - 2}] = 0u; - ${name}_indices[${rank - 1}] = 0u;`; - return resStr; - }; - - const calcResult = (): string => { - let calcStr = `var a_data: ${a.type.value};`; - for (let i = 0; i < aComponents; i++) { - calcStr += ` - let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`; - } - for (let i = 0; i < outputNumber; i++) { - calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; - - for (let j = 0; j < aComponents; j++) { - calcStr += ` - values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`; - } - } - return calcStr; - }; - - return ` - ${shaderHelper - .registerUniforms(uniforms) - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - let col = (global_idx % (uniforms.N / ${components})) * ${components}; - var index1 = global_idx / (uniforms.N / ${components}); - let stride1 = uniforms.M / ${outputNumber}; - let row = (index1 % stride1) * ${outputNumber}; - let batch = index1 / stride1; - - ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`} - ${getIndices(a, broadCastADims)} - let a_offset = ${a.indicesToOffset('a_indices')}; - ${getIndices(b, broadCastBDims)} - let b_offset = ${b.indicesToOffset('b_indices')}; - var values: array<${output.type.value}, ${outputNumber}>; - for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) { - ${calcResult()} - } - for (var i = 0u; i < ${outputNumber}u; i++) { - var value = values[i]; - ${processBias} - ${applyActivation} - let cur_indices = ${output.type.indices}(batch, row + i, col); - let offset = ${output.indicesToOffset('cur_indices')}; - ${output.setByOffset(`offset / ${components}`, 'value')}; - } - } - `; - }; - return { - name: 'MatMulNaive', - shaderCache: { - hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, - inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'], - }, - getRunData: () => ({ - outputs: [ - { - dims: squeezeOutputShapeFunction ? 
squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType, - }, - ], - dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, - programUniforms, - }), - getShaderSource, - }; -}; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -201,6 +29,20 @@ export const matMul = (context: ComputeContext): void => { if (N < 8 && K < 8) { context.compute(createNaiveMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); + const M = outputShape[outputShape.length - 2]; + const batchA = ShapeUtil.size(context.inputs[0].dims.slice(0, -2)); + const batchB = ShapeUtil.size(context.inputs[1].dims.slice(0, -2)); + if (batchA !== 1 && M === 1 && batchB === 1) { + // Optimization for batched vec-mat-mul + const reshapedA = context.inputs[0].reshape([1, batchA, K]); + const reshapedB = context.inputs[1].reshape([1, K, N]); + const matmulOutputShape = [1, batchA, N]; + const matmulInputs = [reshapedA, reshapedB]; + context.compute(createMatmulProgramInfo(matmulInputs, { activation: '' }, outputShape, matmulOutputShape), { + inputs: matmulInputs, + }); + } else { + context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); + } } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 3f4617014e798..3e1f1be22efa2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -266,9 +266,185 @@ export const createMatMulNBitsProgramInfo = ( }; }; +// Currently, only support blockSize = 32. +export const createMatMulNBitsBlockSize32ProgramInfo = ( + inputs: readonly TensorView[], + attributes: MatMulNBitsAttributes, +): ProgramInfo => { + const inputShape = inputs[0].dims; + const aRank = inputShape.length; + const dimAOuter = inputShape[aRank - 2]; + const dimInner = attributes.k; + const dimBOuter = attributes.n; + const batchDims = inputShape.slice(0, aRank - 2); + const batchSize = ShapeUtil.size(batchDims); + const blobSize = inputs[1].dims[2]; + const blobSizeInWords = blobSize / 4; + const dataType = inputs[0].dataType; + const aComponents = getMaxComponents(attributes.k); + const bComponents = getMaxComponents(blobSizeInWords); + const outputShape = batchDims.concat([dimAOuter, dimBOuter]); + + const workgroupSize = 128; + const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1; + const workgroupX = workgroupSize / workgroupY; + const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data. 
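+  // For example, with workgroupY = 8 (so workgroupX = 16) and bComponents = 4, each tile spans 16 * 4 * 8 = 512 positions along K, i.e. 16 blocks of blockSize 32.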
+ const aLengthPerTile = tileSize / aComponents; + const blocksPerTile = tileSize / attributes.blockSize; + const dispatchSize = ShapeUtil.size(outputShape) / workgroupY; + + const programUniforms: ProgramUniform[] = []; + const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents]; + const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); + bShape.splice(-1, 1, blobSizeInWords / bComponents); + programUniforms.push(...createTensorShapeVariables(inputShapeTemp)); + programUniforms.push(...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter]; + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputRank = inputShapeTemp.length; + const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents); + const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const outputRank = outputShapeTemp.length; + const output = outputVariable('output', inputs[0].dataType, outputRank); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const readA = () => { + switch (aComponents) { + case 1: + return ` + let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]); + let a_data1 = vec4<${dataType}>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);`; + case 2: + return ` + let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1]); + let a_data1 = vec4<${dataType}>(sub_a[word_offset + 2], sub_a[word_offset + 3]);`; + case 4: + return ` + let a_data0 = sub_a[word_offset]; + let a_data1 = sub_a[word_offset + 1];`; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + }; + + return ` + var sub_a: array<${a.type.value}, ${aLengthPerTile}>; + var inter_results: array, ${workgroupY}>; + ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart([workgroupX, workgroupY, 1])} + let output_indices = ${output.offsetToIndices(`workgroup_index * ${workgroupY}`)}; + let col = output_indices[2]; + let row = output_indices[1]; + let batch = output_indices[0]; + let n_blocks_per_col = uniforms.b_shape[1]; + let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1; + + // Loop over shared dimension. + for (var tile: u32 = 0; tile < num_tiles; tile += 1) { + let a_col_start = tile * ${aLengthPerTile}; + // load one tile A data into shared memory. + for (var a_offset = local_idx; a_offset < ${aLengthPerTile}; a_offset += ${workgroupSize}) + { + let a_col = a_col_start + a_offset; + if (a_col < uniforms.a_shape[2]) + { + sub_a[a_offset] = ${a.getByIndices(`${a.type.indices}(batch, row, a_col)`)}; + } else { + sub_a[a_offset] = ${a.type.value}(0); + } + } + workgroupBarrier(); + + // each thread process one block + let b_row = col + local_id.y; + let block = tile * ${blocksPerTile} + local_id.x; + ${ + zeroPoints + ? 
` + let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2; + let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u); + let zero_point_word_index = zero_point_byte_count >> 0x2u; + let zero_point_byte_offset = zero_point_byte_count & 0x3u; + let zero_point_nibble_offset: u32 = block & 0x1u; + let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); + let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset; + let zero_point = ${dataType}((zero_point_word) & 0xFu);` + : ` + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${8.0});` + } + let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)}; + let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)}; + var word_offset = local_id.x * ${attributes.blockSize / aComponents}; + for (var i: u32 = 0; i < ${bComponents}; i++) { + ${readA()} + let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`}; + let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); + let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu); + let b_quantized_values = mat2x4<${dataType}>(${Array.from( + { length: 4 }, + (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`, + ).join(', ')}); + let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale; + inter_results[local_id.y][local_id.x] += ${Array.from( + { length: 2 }, + (_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`, + ).join(' + ')}; + word_offset += ${8 / aComponents}; + } + workgroupBarrier(); + } + + if (local_idx < ${workgroupY}) { + var output_value: ${output.type.value} = ${output.type.value}(0); + for (var b = 0u; b < ${workgroupX}; b++) { + output_value += inter_results[local_idx][b]; + } + if (col + local_idx < uniforms.output_shape[2]) + { + ${output.setByIndices(`${output.type.indices}(batch, row, col + local_idx)`, 'output_value')} + } + } + }`; + }; + return { + name: 'BlockwiseMatMulNBits32', + shaderCache: { + hint: `${attributes.blockSize};${aComponents};${bComponents};${workgroupX};${workgroupY}`, + inputDependencies: Array(inputs.length).fill('rank'), + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType }], + dispatchGroup: { x: dispatchSize }, + programUniforms, + }), + getShaderSource, + }; +}; + export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { validateInputs(context.inputs, attributes); - context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); + if ( + attributes.blockSize === 32 && + context.adapterInfo.isVendor('intel') && + context.adapterInfo.isArchitecture('gen-12lp') + ) { + context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes)); + } else { + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); + } }; export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 0949d65174b41..db7a4b8e68b79 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = ( if (input.dims.length === 3) { reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); } + if (numHeads === 1 || 
sequenceLength === 1) { + return reshapedInput; + } return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], @@ -356,6 +359,9 @@ export const maybeTransposeToBNSHAndAddBias = ( biasOffset!, ); reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + if (numHeads === 1 || sequenceLength === 1) { + return reshapedInput; + } return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], @@ -397,19 +403,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio ); if (kvBNSH) { - return applyAttention( - context, - Q, - key, - value, - keyPaddingMask, - undefined, - pastKey, - pastValue, - attentionBias, - params, - attributes, - ); + return applyAttention(context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params); } if (!key || !value) { throw new Error('key and value must be provided'); @@ -436,5 +430,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio 2 * params.hiddenSize, ); - applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes); + applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index bf64b04dde1e8..fe0c3712197c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -4,7 +4,7 @@ import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { ShapeUtil } from '../../util'; -import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types'; +import { ComputeContext, ProgramInfo } from '../types'; import { inputVariable, outputVariable, ShaderHelper } from './common'; import { createReduceAttributesFromInputs, ReduceAttributes } from './reduce'; @@ -119,7 +119,7 @@ const getAxesPermutation = (axes: number[], rank: number): number[] => { export const createReduceSharedProgramInfo = ( name: string, - shaderCache: ProgramShaderCacheInfo, + cacheKey: string, inputs: readonly TensorView[], reduceType: string, outputDataType: DataType, @@ -134,7 +134,11 @@ export const createReduceSharedProgramInfo = ( const input = inputVariable('_A', inputs[0].dataType, inputShape); const output = outputVariable('output', outputDataType, outputShape); - const workgroupSize = 32; + let workgroupSize = 64; + // If only one workgroup is dispatched, increase workgroupSize to improve parallelism. + if (outputSize === 1) { + workgroupSize = 256; + } const sharedMemorySnippet = ` var aBestValues : array; @@ -188,7 +192,8 @@ export const createReduceSharedProgramInfo = ( // One work group is responsible for only one element of output. return { name, - shaderCache, + // Note that in JSEP, WG size is not included in cache by default, but WebGPU EP it is. 
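+    // Including the workgroup size in the hint keeps the 64- and 256-thread variants from reusing each other's compiled shader.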
+ shaderCache: { hint: `${cacheKey};${workgroupSize}`, inputDependencies: ['type'] }, getShaderSource, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: outputDataType }], @@ -233,7 +238,7 @@ const reduceCommon = ( context.compute( createReduceSharedProgramInfo( name, - { hint: updatedAttributes.cacheKey, inputDependencies: ['type'] }, + updatedAttributes.cacheKey, [input], reduceType, context.inputs[0].dataType, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts new file mode 100644 index 0000000000000..8c24232d63c0c --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; + +import { + atomicOutputVariable, + createTensorShapeVariables, + inputVariable, + outputVariable, + ShaderHelper, +} from './common'; + +export interface ScatterNDAttributes extends AttributeWithCacheKey { + reduction: string; +} + +type ReductionType = 'i32' | 'u32' | 'f32'; + +const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: ReductionType) => { + if (reduction !== 'none' && type !== 'i32' && type !== 'u32' && type !== 'f32') { + throw new Error(`Input ${type} is not supported with reduction ${reduction}.`); + } + + const floatStart = `{ + var oldValue = 0; + loop { + let newValueF32 =`; + const floatEnd = `; + let newValue = bitcast(newValueF32); + let res = atomicCompareExchangeWeak(&${ptr}, oldValue, newValue); + if res.exchanged { + break; + } + oldValue = res.old_value; + } + }`; + + switch (reduction) { + case 'none': + return `${ptr}=${v};`; + case 'add': + if (type === 'i32' || type === 'u32') { + return `atomicAdd(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicAdd only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return ` + ${floatStart}bitcast<${type}>(oldValue) + (${v})${floatEnd}`; + } + case 'max': + if (type === 'i32' || type === 'u32') { + return `atomicMax(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicMax only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return ` + ${floatStart}max(bitcast(oldValue), (${v}))${floatEnd}`; + } + case 'min': + if (type === 'i32' || type === 'u32') { + return `atomicMin(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicMin only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return `${floatStart}min(bitcast<${type}>(oldValue), (${v}))${floatEnd}`; + } + case 'mul': + // atomicMul is not supported, we use atomicCompareExchangeWeak to simulate. + return `${floatStart}(bitcast<${type}>(oldValue) * (${v}))${floatEnd}`; + + default: + throw new Error(`Reduction ${reduction} is not supported.`); + } +}; + +const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const indicesShape = inputs[1].dims; + const outputShape = inputShape; + // TODO: support bool with components 4. 
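To make the float path of atomicReductionSnippet above concrete, the call below shows roughly what the helper expands to for an f32 add; the expansion is paraphrased for illustration and is not verbatim output of this change:

// Roughly: a compare-and-swap loop, since WGSL atomicAdd does not accept f32.
//   { var oldValue = 0;
//     loop {
//       let newValueF32 = bitcast<f32>(oldValue) + (value);
//       let newValue = bitcast<i32>(newValueF32);
//       let res = atomicCompareExchangeWeak(&output[data_offset + i], oldValue, newValue);
//       if res.exchanged { break; }
//       oldValue = res.old_value;
//     }
//   }
const addSnippet = atomicReductionSnippet('add', 'output[data_offset + i]', 'value', 'f32');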
+ const components = 1; + const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components); + const lastIndexDimension = indicesShape[indicesShape.length - 1]; + const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: lastIndexDimension }, + { type: DataType.uint32, data: numUpdatesElements }, + ...createTensorShapeVariables(inputs[1].dims, inputs[2].dims, outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const indices = inputVariable('indices', inputs[1].dataType, inputs[1].dims.length); + const updates = inputVariable('updates', inputs[2].dataType, inputs[2].dims.length, components); + const output = + attributes.reduction !== 'none' && attributes.reduction !== '' + ? atomicOutputVariable('output', inputs[0].dataType, outputShape.length) + : outputVariable('output', inputs[0].dataType, outputShape.length, components); + + return ` + ${shaderHelper + .registerUniform('output_size', 'u32') + .registerUniform('last_index_dimension', 'u32') + .registerUniform('num_updates_elements', 'u32') + .declareVariables(indices, updates, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var data_offset = 0u; + let indices_start = uniforms.last_index_dimension * global_idx; + let indices_end = indices_start + uniforms.last_index_dimension; + for (var i = indices_start; i < indices_end; i++) { + var index = i32(indices[i].x); + ${ + inputs[0].dims.length === 1 + ? ` + let element_count_dim = uniforms.output_strides; + let dim_value = uniforms.output_shape;` + : ` + let element_count_dim = uniforms.output_strides[i - indices_start]; + let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];` + } + if (index >= 0) { + if (index >= i32(dim_value)) { + index = i32(dim_value - 1); + } + } else { + if (index < -i32(dim_value)) { + index = 0; + } else { + index += i32(dim_value); + } + } + data_offset += u32((u32(index) * element_count_dim)); + } + + for (var i = 0u; i < uniforms.num_updates_elements; i++) { + let value = updates[uniforms.num_updates_elements * global_idx + i]; + ${atomicReductionSnippet( + attributes.reduction, + 'output[data_offset + i]', + 'value', + output.type.value as ReductionType, + )} + } + + }`; + }; + return { + name: 'ScatterND', + shaderCache: { + hint: `${attributes.cacheKey}_${attributes.reduction}`, + inputDependencies: ['rank', 'rank'], + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; + +export const parseScatterNDAttributes = (attributes: Record): ScatterNDAttributes => + createAttributeWithCacheKey({ reduction: attributes.reduction as string }); + +export const scatterND = (context: ComputeContext, attributes: ScatterNDAttributes): void => { + context.compute(createScatterNDProgramInfo(context.inputs, attributes), { + inputs: [context.inputs[1], context.inputs[2]], + outputs: [], + }); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index fbab44e211946..7c62d1f7182a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -35,7 +35,6 @@ const createSoftmaxProgramInfo = (context: ComputeContext, 
attributes: SoftmaxAt const input = context.inputs[0]; const inputShape = input.dims; const outputSize = ShapeUtil.size(inputShape); - const WG = 64; const inputRank = inputShape.length; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); const isTransposeRequired = axis < inputShape.length - 1; @@ -60,7 +59,11 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt const rows = outputSize / cols; const components = getMaxComponents(cols); const packedCols = cols / components; - + let WG = 64; + // If only one workgroup is dispatched, increase workgroupSize to improve parallelism. + if (rows === 1) { + WG = 256; + } const maxVector = (name: string, components: number) => { if (components === 4) { return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`; @@ -95,7 +98,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt result[index] = value; } ${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)} - ${shaderHelper.mainStart()} + ${shaderHelper.mainStart(WG)} let gindex = i32(global_idx); let lindex = i32(local_idx); const wg = ${WG}; @@ -156,7 +159,8 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt const result = context.compute( { name: 'Softmax', - shaderCache: { hint: `${components}`, inputDependencies: ['type'] }, + // Note that in JSEP, WG size is not included in cache by default, but WebGPU EP it is. + shaderCache: { hint: `${components};${WG}`, inputDependencies: ['type'] }, getRunData: () => ({ outputs: [{ dims: transposedInputShape, dataType: transposedInput.dataType }], dispatchGroup: { x: rows }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 1dc3a206cf94b..8c39505734e41 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => { }`; }; -const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { +export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); const dataType = inputs[0].dataType; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 1fd99d085e0ed..5059645211aea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -29,7 +29,9 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} { var a: ${input.type.indices};`; for (let i = 0; i < rank; ++i) { - reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`); + // input indices and output indices should always be larger or equal to 2, + // so indexer is always valid to be used on `a` and `i`. + reverseFunc += `a[${perm[i]}]=i[${i}];`; } return (reverseFunc += 'return a;}'); }; @@ -48,17 +50,61 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh return { newShape, newPerm }; }; +const isTransposeReshape = (perm: number[], shape: readonly number[]) => { + // As long as the dims with values > 1 stay in the same order, it's a reshape. + // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). 
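+  // Here the non-unit dims keep their input order (1024 from axis 2 before 4096 from axis 3), so only size-1 axes move and the permutation is effectively a reshape.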
+ let lastPermutedAxis = 0; + for (let i = 0; i < perm.length; ++i) { + if (shape[perm[i]] === 1) { + continue; + } + if (perm[i] < lastPermutedAxis) { + return false; + } + lastPermutedAxis = perm[i]; + } + return true; +}; + export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); const outputShape = getOutputShape(inputTensor.dims, perm); + let newInputShape = inputTensor.dims; + let newOutputShape = outputShape; + const transposeAsReshape = inputRank < 2 || isTransposeReshape(perm, inputTensor.dims); + let getShaderSource; + if (transposeAsReshape) { + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('input', inputDataType, newInputShape, 4); + const output = outputVariable('output', inputDataType, newOutputShape, 4); + return ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + output[global_idx] = input[global_idx]; + }`; + }; + + return { + name: 'TransposeCopy', + shaderCache: { inputDependencies: ['type'] }, + getRunData: () => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* components */) }, + programUniforms: [{ type: DataType.uint32, data: Math.ceil(outputSize / 4) }], + }; + }, + getShaderSource, + }; + } const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm); const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]); const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]); - const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst; - let newInputShape = useShared ? newShape : inputTensor.dims; - let newOutputShape = outputShape; + const useShared = newShape.length === 2 || channelsLast || channelsFirst; if (useShared) { newInputShape = channelsLast ? [newShape[0], newShape[1] * newShape[2]] @@ -66,13 +112,11 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ? 
[newShape[0] * newShape[1], newShape[2]] : newShape; newOutputShape = [newInputShape[1], newInputShape[0]]; - } - const input = inputVariable('a', inputDataType, newInputShape.length); - const output = outputVariable('output', inputDataType, newOutputShape.length); - const tileSize = 16; - let getShaderSource; - if (useShared) { - getShaderSource = (shaderHelper: ShaderHelper) => ` + const tileSize = 16; + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + return ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} var tile : array, ${tileSize}>; ${shaderHelper.mainStart([tileSize, tileSize, 1])} @@ -92,8 +136,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')} } }`; - } else { - getShaderSource = (shaderHelper: ShaderHelper) => ` + }; + return { + name: 'TransposeShared', + shaderCache: { inputDependencies: ['type'] }, + getRunData: () => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], + dispatchGroup: { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(newInputShape, newOutputShape), + ], + }; + }, + getShaderSource, + }; + } + + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + return ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -106,17 +171,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; - } + }; return { - name: useShared ? 'TransposeShared' : 'Transpose', + name: 'Transpose', shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] }, getRunData: () => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], - dispatchGroup: useShared - ? 
{ x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) } - : { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms: [ { type: DataType.uint32, data: outputSize }, ...createTensorShapeVariables(newInputShape, newOutputShape), diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index c5b8f579c3aae..2c5180c5db3ee 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -93,13 +93,23 @@ export class ProgramManager { build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { TRACE_FUNC_BEGIN(programInfo.name); const device = this.backend.device; - const extensions: string[] = []; - if (device.features.has('shader-f16')) { - extensions.push('enable f16;'); - } + const enableDirectives: string[] = []; + + // Enable WGSL extensions based on available WebGPU features + const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [ + { feature: 'shader-f16', extension: 'f16' }, + { feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' }, + { feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' }, + ]; + extensionsInfo.forEach((info) => { + if (device.features.has(info.feature)) { + enableDirectives.push(`enable ${info.extension};`); + } + }); + const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits); const userCode = programInfo.getShaderSource(shaderHelper); - const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; + const code = `${enableDirectives.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({ code, label: programInfo.name }); LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`); diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 776263b143be3..9321ac170d036 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -15,12 +15,17 @@ export enum GpuDataType { } export type GpuDataId = number; -export type GpuArchitecture = 'ampere'; +export type GpuArchitecture = 'ampere' | 'gen-12lp'; export type GpuVendor = 'amd' | 'intel' | 'nvidia'; export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean; isVendor: (vendor: GpuVendor) => boolean; } +export interface DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; +} export interface GpuData { type: GpuDataType; @@ -160,6 +165,11 @@ export interface ComputeContext { */ readonly adapterInfo: AdapterInfo; + /** + * gpu device info + */ + readonly deviceInfo: DeviceInfo; + /** * stores the pointer to OpKernelContext */ @@ -187,8 +197,6 @@ export interface ComputeContext { compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[]; output(index: number, dims: readonly number[]): number; - getMaxComputeWorkgroupSizes(): [number, number, number]; - getMaxComputeWorkgroupStoragesize(): number; } export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes'; diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 9475de019ed1d..45b5b8b4fa932 100644 --- 
a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -42,9 +42,9 @@ export interface TensorManager { download(tensorId: TensorId): Promise; download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise; /** - * Release all tensors for a MLContext. + * Release all tensors for a given session. */ - releaseTensorsForContext(mlContext: MLContext): void; + releaseTensorsForSession(session: number): void; /** * Register an externally created MLTensor with a given MLContext and return a TensorId. */ @@ -54,121 +54,176 @@ export interface TensorManager { let tensorGuid = 1; const createNewTensorId = (): TensorId => tensorGuid++; -export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]]; +/** + * Map from MLOperandDataType to size in bits. Using bits instead of bytes to avoid possible precision loss on int4 and uint4. + */ +const webnnDataTypeToSize = new Map([ + ['float32', 32], + ['float16', 16], + ['int32', 32], + ['uint32', 32], + ['int64', 64], + ['uint64', 64], + ['int8', 8], + ['uint8', 8], + ['int4', 4], + ['uint4', 4], +]); /** - * TensorTracker tracks the MLTensor and pending upload data. - * - * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until - * we know the data type and shape. This is because future implementations of WebNN will only support creating - * MLTensors with dataTypes and shape. + * Calculate the byte length of a tensor with the given data type and shape. */ -class TensorTracker { - private tensorEntry?: MLTensorEntry; - private activeUpload?: Uint8Array; - private tensorCache: MLTensorEntry[]; +const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number[]): number => { + const size = webnnDataTypeToSize.get(dataType); + if (!size) { + throw new Error('Unsupported data type.'); + } + return shape.length > 0 ? Math.ceil((shape.reduce((a, b) => a * b) * size) / 8) : 0; +}; - constructor( - private mlContext?: MLContext, - tensorEntry?: MLTensorEntry, - ) { - this.tensorEntry = tensorEntry; - this.tensorCache = tensorEntry ? [tensorEntry] : []; +/** + * TensorWrapper wraps an MLTensor and provides a way to track the last session that used it. + */ +class TensorWrapper { + // The id of the last session that used this tensor. 
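+  // getCachedTensor() refreshes this id whenever a pooled tensor is reused, and releaseTensorsForSession()
+  // uses it to decide which free tensors to destroy.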
+ public sessionId: number; + + private mlContext: MLContext; + private mlTensor: MLTensor; + private dataType: MLOperandDataType; + private tensorShape: readonly number[]; + + constructor(descriptor: { + sessionId: number; + context: MLContext; + tensor: MLTensor; + dataType: MLOperandDataType; + shape: readonly number[]; + }) { + this.sessionId = descriptor.sessionId; + this.mlContext = descriptor.context; + this.mlTensor = descriptor.tensor; + this.dataType = descriptor.dataType; + this.tensorShape = descriptor.shape; } - public get tensor(): MLTensor | undefined { - return this.tensorEntry?.[0]; + public get tensor(): MLTensor { + return this.mlTensor; } - public get context(): MLContext { - if (!this.mlContext) { - throw new Error('MLContext has not been set.'); - } - return this.mlContext; + public get type(): MLOperandDataType { + return this.dataType; } - public set context(mlContext: MLContext) { - if (this.mlContext && this.mlContext !== mlContext) { - throw new Error('MLTensor in use in a different MLContext.'); - } - this.mlContext = mlContext; + public get shape(): readonly number[] { + return this.tensorShape; + } + + public get byteLength(): number { + return calculateByteLength(this.dataType, this.tensorShape); } public destroy(): void { - for (const [mlTensor] of this.tensorCache) { - mlTensor.destroy(); + LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy'); + this.mlTensor.destroy(); + } + + public write(data: Uint8Array): void { + this.mlContext.writeTensor(this.mlTensor, data); + } + + public async read(): Promise; + public async read(dstBuffer: ArrayBufferView | ArrayBuffer): Promise; + async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { + if (dstBuffer) { + return this.mlContext.readTensor(this.mlTensor, dstBuffer); } - this.tensorCache = []; - this.tensorEntry = undefined; + return this.mlContext.readTensor(this.mlTensor); } - public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean { - for (const [mlTensor, dataType, shape] of this.tensorCache) { - if (tryMLTensor === mlTensor) { - if (this.context !== context) { - throw new Error('MLTensor cannot be registered with a different MLContext.'); - } - this.tensorEntry = [mlTensor, dataType, shape]; - return true; - } + public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean { + return ( + this.mlContext === context && + this.dataType === dataType && + this.tensorShape.length === shape.length && + this.tensorShape.every((v, i) => v === shape[i]) + ); + } +} + +/** + * TensorIdTracker tracks the MLTensor and pending upload data. + * + * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until + * we know the data type and shape. This is because WebNN only supports creating MLTensors with dataTypes and shape.
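+ * For example, upload() issued before the MLTensor exists just stages the bytes in activeUpload, and download()
+ * can serve a read directly from that staging buffer.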
+ */ +class TensorIdTracker { + private activeUpload?: Uint8Array; + + constructor( + private tensorManager: TensorManagerImpl, + private wrapper?: TensorWrapper, + ) {} + + public get tensorWrapper(): TensorWrapper | undefined { + return this.wrapper; + } + + public releaseTensor(): void { + if (this.tensorWrapper) { + this.tensorManager.releaseTensor(this.tensorWrapper); + this.wrapper = undefined; } - return false; } public async ensureTensor( + context: MLContext, dataType: MLOperandDataType, shape: readonly number[], copyOld: boolean, ): Promise { - if (this.tensorEntry) { - const [mlTensor, existingDataType, existingShape] = this.tensorEntry; - if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { - return mlTensor; - } - } - - for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) { - if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { - if (copyOld && this.tensorEntry) { - // WebNN does not support copyTensorToTensor, so we need to read and write the tensors. - LOG_DEBUG( - 'verbose', - () => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${dataType}, shape: ${shape}}`, - ); - const data = await this.context.readTensor(this.tensorEntry[0]); - this.context.writeTensor(mlTensor, data); + if (this.wrapper) { + if (this.wrapper.canReuseTensor(context, dataType, shape)) { + return this.wrapper.tensor; + } else { + if (copyOld) { + if (this.wrapper.byteLength !== calculateByteLength(dataType, shape)) { + throw new Error('Unable to copy data to tensor with different size.'); + } + this.activeUpload = new Uint8Array(await this.wrapper.read()); } - this.tensorEntry = [mlTensor, existingDataType, existingShape]; - return mlTensor; + this.tensorManager.releaseTensor(this.wrapper); } } - LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); + // eslint-disable-next-line no-bitwise - const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; - const tensor = await this.context.createTensor({ - dataType, - shape, - // Assign both shape and dimensions while transitioning to new API. - dimensions: shape, - usage, - }); - this.tensorEntry = [tensor, dataType, shape]; - this.tensorCache.push(this.tensorEntry); + const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE; + this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true); - if (this.activeUpload) { - this.mlContext?.writeTensor(tensor, this.activeUpload); + if (copyOld && this.activeUpload) { + this.wrapper.write(this.activeUpload); this.activeUpload = undefined; } - return tensor; + return this.wrapper.tensor; } public upload(data: Uint8Array): void { - if (!this.tensorEntry) { + if (this.wrapper) { + if (data.byteLength === this.wrapper.byteLength) { + this.wrapper.write(data); + return; + } else { + LOG_DEBUG('verbose', () => 'Data size does not match tensor size. 
Releasing tensor.'); + this.releaseTensor(); + } + } + + if (this.activeUpload) { + this.activeUpload.set(data); + } else { this.activeUpload = new Uint8Array(data); - return; } - this.mlContext?.writeTensor(this.tensorEntry[0], data); } public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { @@ -179,49 +234,42 @@ class TensorTracker { } else { new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload); } - return; } else { return this.activeUpload.buffer; } } - if (!this.tensorEntry) { + if (!this.wrapper) { throw new Error('Tensor has not been created.'); } - if (dstBuffer) { - return this.context.readTensor(this.tensorEntry[0], dstBuffer); + if (!dstBuffer) { + return this.wrapper.read(); } - return this.context.readTensor(this.tensorEntry[0]); + return this.wrapper.read(dstBuffer); } } class TensorManagerImpl implements TensorManager { - private tensorsById = new Map(); - private tensorIdsByContext = new Map>(); + private tensorTrackersById: Map = new Map(); + private freeTensors: TensorWrapper[] = []; + private externalTensors: Set = new Set(); constructor(private backend: WebNNBackend) {} public reserveTensorId(): TensorId { const tensorId = createNewTensorId(); - this.tensorsById.set(tensorId, new TensorTracker()); + this.tensorTrackersById.set(tensorId, new TensorIdTracker(this)); return tensorId; } public releaseTensorId(tensorId: TensorId): void { - const tensorTracker = this.tensorsById.get(tensorId); + const tensorTracker = this.tensorTrackersById.get(tensorId); if (!tensorTracker) { return; } - tensorTracker.destroy(); - this.tensorsById.delete(tensorId); - for (const [mlContext, tensors] of this.tensorIdsByContext) { - if (tensors.has(tensorId)) { - tensors.delete(tensorId); - if (tensors.size === 0) { - this.tensorIdsByContext.delete(mlContext); - } - break; - } + this.tensorTrackersById.delete(tensorId); + if (tensorTracker.tensorWrapper) { + this.releaseTensor(tensorTracker.tensorWrapper); } } @@ -238,20 +286,19 @@ class TensorManagerImpl implements TensorManager { dataType }, shape: ${shape}, copyOld: ${copyOld}}`, ); - const tensor = this.tensorsById.get(tensorId); + const tensor = this.tensorTrackersById.get(tensorId); if (!tensor) { throw new Error('Tensor not found.'); } - tensor.context = this.backend.currentContext; - if (!this.tensorIdsByContext.has(this.backend.currentContext)) { - this.tensorIdsByContext.set(this.backend.currentContext, new Set()); - } - this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId); - return tensor.ensureTensor(dataType, shape, copyOld); + return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld); } public upload(tensorId: TensorId, data: Uint8Array): void { - this.tensorsById.get(tensorId)!.upload(data); + const tensor = this.tensorTrackersById.get(tensorId); + if (!tensor) { + throw new Error('Tensor not found.'); + } + tensor.upload(data); } public async download(tensorId: TensorId): Promise; @@ -261,19 +308,20 @@ class TensorManagerImpl implements TensorManager { 'verbose', () => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`, ); - return this.tensorsById.get(tensorId)!.download(dstBuffer); + const tensorTracker = this.tensorTrackersById.get(tensorId); + if (!tensorTracker) { + throw new Error('Tensor not found.'); + } + return tensorTracker.download(dstBuffer); } - public releaseTensorsForContext(mlContext: MLContext): void { - const tensors = this.tensorIdsByContext.get(mlContext); - 
if (!tensors) { - return; - } - for (const tensorId of tensors) { - this.tensorsById.get(tensorId)!.destroy(); - this.tensorsById.delete(tensorId); + public releaseTensorsForSession(sessionId: number): void { + for (const tensor of this.freeTensors) { + if (tensor.sessionId === sessionId) { + tensor.destroy(); + } } - this.tensorIdsByContext.delete(mlContext); + this.freeTensors = this.freeTensors.filter((tensor) => tensor.sessionId !== sessionId); } public registerTensor( @@ -282,20 +330,61 @@ class TensorManagerImpl implements TensorManager { dataType: MLOperandDataType, shape: readonly number[], ): TensorId { - for (const [tensorId, tensorTracker] of this.tensorsById) { - if (tensorTracker.trySelectTensor(mlContext, mlTensor)) { - return tensorId; + const tensorId = createNewTensorId(); + // Defaulting to READ | WRITE if usage is not provided. + // eslint-disable-next-line no-bitwise + const wrapper = new TensorWrapper({ + sessionId: this.backend.currentSessionId, + context: mlContext, + tensor: mlTensor, + dataType, + shape, + }); + this.tensorTrackersById.set(tensorId, new TensorIdTracker(this, wrapper)); + this.externalTensors.add(wrapper); + return tensorId; + } + + /** + * Get or create an MLTensor with the given data type and shape. + */ + public async getCachedTensor( + dataType: MLOperandDataType, + shape: readonly number[], + usage: MLTensorUsageFlags | undefined, + writable: boolean, + readable: boolean, + ): Promise { + const sessionId = this.backend.currentSessionId; + const context = this.backend.currentContext; + for (const [index, tensor] of this.freeTensors.entries()) { + if (tensor.canReuseTensor(context, dataType, shape)) { + LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`); + const wrapper = this.freeTensors.splice(index, 1)[0]; + wrapper.sessionId = sessionId; + return wrapper; } } - const tensorId = createNewTensorId(); - this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape])); - let tensors = this.tensorIdsByContext.get(mlContext); - if (!tensors) { - tensors = new Set(); - this.tensorIdsByContext.set(mlContext, tensors); + LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); + const tensor = await context.createTensor({ + dataType, + shape, + dimensions: shape, + usage, + writable, + readable, + }); + return new TensorWrapper({ sessionId, context, tensor, dataType, shape }); + } + + /** + * Release tensor for reuse unless external. 
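+   * Released wrappers are pushed onto the freeTensors pool so that getCachedTensor() can hand them out again for a
+   * matching context, data type and shape instead of calling MLContext.createTensor().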
+ */ + public releaseTensor(tensorWrapper: TensorWrapper) { + if (this.externalTensors.has(tensorWrapper)) { + this.externalTensors.delete(tensorWrapper); } - tensors.add(tensorId); - return tensorId; + this.freeTensors.push(tensorWrapper); } } diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 5cb0f4e74c3df..c513b2ec2ed8b 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -13,7 +13,6 @@ type MLPowerPreference = 'default'|'high-performance'|'low-power'; interface MLContextOptions { deviceType?: MLDeviceType; powerPreference?: MLPowerPreference; - numThreads?: number; } interface ML { createContext(options?: MLContextOptions): Promise; @@ -29,7 +28,7 @@ interface MLContext { } interface MLGraph {} type MLInputOperandLayout = 'nchw'|'nhwc'; -type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'; +type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'|'int4'|'uint4'; interface MLOperandDescriptor { dataType: MLOperandDataType; shape?: readonly number[]; @@ -37,8 +36,8 @@ interface MLOperandDescriptor { dimensions?: readonly number[]; } interface MLOperand { - dataType(): MLOperandDataType; - shape(): number[]; + dataType: MLOperandDataType; + shape: readonly number[]; } interface MLActivation {} type MLNamedOperands = Record; @@ -393,6 +392,7 @@ type MLNamedTensor = Record; type MLTensorUsageFlags = number; +// TODO(@Honry): Remove this once it is deprecated in Chromium. declare const MLTensorUsage: { readonly WEBGPU_INTEROP: MLTensorUsageFlags; readonly READ: MLTensorUsageFlags; @@ -400,7 +400,11 @@ declare const MLTensorUsage: { }; interface MLTensorDescriptor extends MLOperandDescriptor { - usage: MLTensorUsageFlags; + /** @deprecated Use readable/writeable instead of usage */ + usage: MLTensorUsageFlags | undefined; + importableToWebGPU?: boolean; + readable?: boolean; + writable?: boolean; } interface MLContext { diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index b2594267a595a..17e564247863d 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -200,7 +200,9 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n return [sessionOptionsHandle, allocs]; } catch (e) { if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError("Can't release session options."); + } } allocs.forEach((alloc) => wasm._free(alloc)); throw e; diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index ad2ff62587252..54071866be5c3 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -252,7 +252,9 @@ export const isMLTensorSupportedType = (type: Tensor.Type): type is Tensor.MLTen type === 'uint64' || type === 'int8' || type === 'uint8' || - type === 'bool'; + type === 'bool' || + type === 'uint4' || + type === 'int4'; /** * Map string data location to integer value diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 0668ac1931988..da8939cd0263a 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -207,12 +207,14 @@ const getSessionInputOutputCount = (sessionHandle: number): [number, number] => const wasm = getInstance(); const stack = wasm.stackSave(); try { - const dataOffset = wasm.stackAlloc(8); - 
const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); + const ptrSize = wasm.PTR_SIZE; + const dataOffset = wasm.stackAlloc(2 * ptrSize); + const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + ptrSize); if (errorCode !== 0) { checkLastError("Can't get session input/output count."); } - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + const type = ptrSize === 4 ? 'i32' : 'i64'; + return [Number(wasm.getValue(dataOffset, type)), Number(wasm.getValue(dataOffset + ptrSize, type))]; } finally { wasm.stackRestore(stack); } @@ -289,25 +291,21 @@ export const createSession = async ( const providerName = typeof provider === 'string' ? provider : provider.name; if (providerName === 'webnn') { wasm.shouldTransferToMLTensor = false; - if (wasm.currentContext) { - throw new Error('WebNN execution provider is already set.'); - } if (typeof provider !== 'string') { const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption; const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice; const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; - const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; if (context) { wasm.currentContext = context as MLContext; } else if (gpuDevice) { - wasm.currentContext = await navigator.ml.createContext(gpuDevice); + wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice); } else { - wasm.currentContext = await navigator.ml.createContext({ deviceType, numThreads, powerPreference }); + wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference }); } } else { - wasm.currentContext = await navigator.ml.createContext(); + wasm.currentContext = await wasm.jsepCreateMLContext!(); } break; } @@ -318,6 +316,8 @@ export const createSession = async ( checkLastError("Can't create a session."); } + wasm.jsepOnCreateSession?.(); + // clear current MLContext after session creation if (wasm.currentContext) { wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext); @@ -399,17 +399,23 @@ export const createSession = async ( outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); if (ioBindingHandle !== 0) { - wasm._OrtReleaseBinding(ioBindingHandle); + if (wasm._OrtReleaseBinding(ioBindingHandle) !== 0) { + checkLastError("Can't release IO binding."); + } } if (sessionHandle !== 0) { - wasm._OrtReleaseSession(sessionHandle); + if (wasm._OrtReleaseSession(sessionHandle) !== 0) { + checkLastError("Can't release session."); + } } throw e; } finally { wasm._free(modelDataOffset); if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError("Can't release session options."); + } } allocs.forEach((alloc) => wasm._free(alloc)); @@ -428,16 +434,22 @@ export const releaseSession = (sessionId: number): void => { if (ioBindingState) { if (enableGraphCapture) { - wasm._OrtClearBoundOutputs(ioBindingState.handle); + if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { + checkLastError("Can't clear bound outputs."); + } + } + if (wasm._OrtReleaseBinding(ioBindingState.handle) !== 0) { + checkLastError("Can't release IO binding."); } - 
wasm._OrtReleaseBinding(ioBindingState.handle); } wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); - wasm._OrtReleaseSession(sessionHandle); + if (wasm._OrtReleaseSession(sessionHandle) !== 0) { + checkLastError("Can't release session."); + } activeSessions.delete(sessionId); }; @@ -455,6 +467,7 @@ export const prepareInputOutputTensor = ( } const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const dataType = tensor[0]; const dims = tensor[1]; @@ -474,7 +487,7 @@ export const prepareInputOutputTensor = ( } if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const gpuBuffer = tensor[2].gpuBuffer; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; const registerBuffer = wasm.jsepRegisterBuffer; @@ -496,15 +509,14 @@ export const prepareInputOutputTensor = ( if (Array.isArray(data)) { // string tensor - dataByteLength = 4 * data.length; + dataByteLength = ptrSize * data.length; rawData = wasm._malloc(dataByteLength); allocs.push(rawData); - let dataIndex = rawData / 4; for (let i = 0; i < data.length; i++) { if (typeof data[i] !== 'string') { throw new TypeError(`tensor data at index ${i} is not a string`); } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*'); } } else { dataByteLength = data.byteLength; @@ -517,8 +529,7 @@ export const prepareInputOutputTensor = ( const stack = wasm.stackSave(); const dimsOffset = wasm.stackAlloc(4 * dims.length); try { - let dimIndex = dimsOffset / 4; - dims.forEach((d) => (wasm.HEAP32[dimIndex++] = d)); + dims.forEach((d, index) => wasm.setValue(dimsOffset + index * ptrSize, d, ptrSize === 4 ? 'i32' : 'i64')); const tensor = wasm._OrtCreateTensor( tensorDataTypeStringToEnum(dataType), rawData, @@ -548,6 +559,7 @@ export const run = async ( options: InferenceSession.RunOptions, ): Promise => { const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); @@ -570,10 +582,10 @@ export const run = async ( const inputOutputAllocs: number[] = []; const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); + const inputValuesOffset = wasm.stackAlloc(inputCount * ptrSize); + const inputNamesOffset = wasm.stackAlloc(inputCount * ptrSize); + const outputValuesOffset = wasm.stackAlloc(outputCount * ptrSize); + const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize); try { // WebNN backend needs the active session to check MLTensors with the current context. 
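+    // The input/output handle and name arrays allocated above are sized and indexed in units of wasm.PTR_SIZE
+    // (4 bytes for 32-bit wasm, 8 bytes for 64-bit wasm) and are accessed through getValue()/setValue() rather than
+    // the fixed-width HEAPU32 views, so the same code works for either pointer width.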
@@ -605,17 +617,13 @@ export const run = async ( ); } - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*'); + wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], '*'); } for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], '*'); + wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], '*'); } if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) { @@ -699,7 +707,7 @@ export const run = async ( const output: TensorMetadata[] = []; for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*')); if (tensor === outputTensorHandles[i]) { // output tensor is pre-allocated. no need to copy data. output.push(outputTensors[i]!); @@ -708,7 +716,7 @@ export const run = async ( const beforeGetTensorDataStack = wasm.stackSave(); // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const tensorDataOffset = wasm.stackAlloc(4 * ptrSize); let keepOutputTensor = false; let type: Tensor.Type | undefined, @@ -717,24 +725,26 @@ export const run = async ( const errorCode = wasm._OrtGetTensorData( tensor, tensorDataOffset, - tensorDataOffset + 4, - tensorDataOffset + 8, - tensorDataOffset + 12, + tensorDataOffset + ptrSize, + tensorDataOffset + 2 * ptrSize, + + tensorDataOffset + 3 * ptrSize, ); if (errorCode !== 0) { checkLastError(`Can't access output tensor data on index ${i}.`); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const valueType = ptrSize === 4 ? 'i32' : 'i64'; + const dataType = Number(wasm.getValue(tensorDataOffset, valueType)); + dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*'); + const dimsOffset = wasm.getValue(tensorDataOffset + ptrSize * 2, '*'); + const dimsLength = Number(wasm.getValue(tensorDataOffset + ptrSize * 3, valueType)); const dims = []; for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType))); + } + if (wasm._OrtFree(dimsOffset) !== 0) { + checkLastError("Can't free memory for tensor dims."); } - wasm._OrtFree(dimsOffset); - const size = dims.reduce((a, b) => a * b, 1); type = tensorDataTypeEnumToString(dataType); @@ -745,10 +755,10 @@ export const run = async ( throw new Error('String tensor is not supported on GPU.'); } const stringData: string[] = []; - let dataIndex = dataOffset / 4; for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? 
undefined : wasm.HEAPU32[dataIndex] - offset; + const offset = wasm.getValue(dataOffset + i * ptrSize, '*'); + const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*'); + const maxBytesToRead = i === size - 1 ? undefined : nextOffset - offset; stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } output.push([type, dims, stringData, 'cpu']); @@ -776,7 +786,9 @@ export const run = async ( gpuBuffer, download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), dispose: () => { - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError("Can't release tensor."); + } }, }, 'gpu-buffer', @@ -833,7 +845,9 @@ export const run = async ( } if (ioBindingState && !enableGraphCapture) { - wasm._OrtClearBoundOutputs(ioBindingState.handle); + if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { + checkLastError("Can't clear bound outputs."); + } activeSessions.set(sessionId, [ sessionHandle, inputNamesUTF8Encoded, diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 3e08fe97f559d..ebeac5dc9e587 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -141,6 +141,12 @@ export declare namespace JSEP { * @param sessionId - specify the session ID. */ jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from pre-jsep.js] Create a session. This function will be called after _OrtCreateSession() is + * called. + * @returns + */ + jsepOnCreateSession: () => void; /** * [exported from pre-jsep.js] Release a session. This function will be called before _OrtReleaseSession() is * called. @@ -219,21 +225,45 @@ export declare namespace JSEP { * @returns the MLTensor ID for the external MLTensor. */ jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; + + /** + * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions. + * @param optionsOrGpuDevice - specify the options or GPUDevice. + * @returns + */ + jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise; + + /** + * [exported from pre-jsep.js] Register a WebNN Constant operand from external data. + * @param externalFilePath - specify the external file path. + * @param dataOffset - specify the external data offset. + * @param dataLength - specify the external data length. + * @param builder - specify the MLGraphBuilder used for constructing the Constant. + * @param desc - specify the MLOperandDescriptor of the Constant. + * @returns the WebNN Constant operand for the specified external data. 
+ */ + jsepRegisterMLConstant( + externalFilePath: string, + dataOffset: number, + dataLength: number, + builder: MLGraphBuilder, + desc: MLOperandDescriptor, + ): MLOperand; } } export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; - _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; + _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): number; _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; - _OrtReleaseSession(sessionHandle: number): void; + _OrtReleaseSession(sessionHandle: number): number; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; _OrtGetOutputName(sessionHandle: number, index: number): number; - _OrtFree(stringHandle: number): void; + _OrtFree(stringHandle: number): number; _OrtCreateTensor( dataType: number, @@ -250,12 +280,12 @@ export interface OrtInferenceAPIs { dimsOffset: number, dimsLength: number, ): number; - _OrtReleaseTensor(tensorHandle: number): void; + _OrtReleaseTensor(tensorHandle: number): number; _OrtCreateBinding(sessionHandle: number): number; _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; - _OrtClearBoundOutputs(ioBindingHandle: number): void; - _OrtReleaseBinding(ioBindingHandle: number): void; + _OrtClearBoundOutputs(ioBindingHandle: number): number; + _OrtReleaseBinding(ioBindingHandle: number): number; _OrtRunWithBinding( sessionHandle: number, ioBindingHandle: number, @@ -289,11 +319,11 @@ export interface OrtInferenceAPIs { _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; - _OrtReleaseSessionOptions(sessionOptionsHandle: number): void; + _OrtReleaseSessionOptions(sessionOptionsHandle: number): number; _OrtCreateRunOptions(logSeverityLevel: number, logVerbosityLevel: number, terminate: boolean, tag: number): number; _OrtAddRunConfigEntry(runOptionsHandle: number, configKey: number, configValue: number): number; - _OrtReleaseRunOptions(runOptionsHandle: number): void; + _OrtReleaseRunOptions(runOptionsHandle: number): number; _OrtEndProfiling(sessionHandle: number): number; } @@ -302,10 +332,13 @@ export interface OrtInferenceAPIs { * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. 
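+ * Note: getValue()/setValue() below are standard Emscripten runtime helpers; PTR_SIZE is assumed to be injected by
+ * the onnxruntime-web build glue and reflects the pointer width of the generated wasm binary.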
*/ export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { + PTR_SIZE: number; // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; stackAlloc(size: number): number; + getValue(ptr: number, type: string): number; + setValue(ptr: number, value: number, type: string): void; UTF8ToString(offset: number, maxBytesToRead?: number): string; lengthBytesUTF8(str: string): number; diff --git a/js/web/lib/wasm/wasm-utils.ts b/js/web/lib/wasm/wasm-utils.ts index a820fd216ee03..9ce39c366dc77 100644 --- a/js/web/lib/wasm/wasm-utils.ts +++ b/js/web/lib/wasm/wasm-utils.ts @@ -55,10 +55,11 @@ export const checkLastError = (message: string): void => { const stack = wasm.stackSave(); try { - const paramsOffset = wasm.stackAlloc(8); - wasm._OrtGetLastError(paramsOffset, paramsOffset + 4); - const errorCode = wasm.HEAP32[paramsOffset / 4]; - const errorMessagePointer = wasm.HEAPU32[paramsOffset / 4 + 1]; + const ptrSize = wasm.PTR_SIZE; + const paramsOffset = wasm.stackAlloc(2 * ptrSize); + wasm._OrtGetLastError(paramsOffset, paramsOffset + ptrSize); + const errorCode = Number(wasm.getValue(paramsOffset, ptrSize === 4 ? 'i32' : 'i64')); + const errorMessagePointer = wasm.getValue(paramsOffset + ptrSize, '*'); const errorMessage = errorMessagePointer ? wasm.UTF8ToString(errorMessagePointer) : ''; throw new Error(`${message} ERROR_CODE: ${errorCode}, ERROR_MESSAGE: ${errorMessage}`); } finally { diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 6e723a76e8fd8..07c8f0bf3b940 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.20.0", + "version": "1.21.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.20.0", + "version": "1.21.0", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -51,7 +51,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.20.0", + "version": "1.21.0", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" @@ -839,9 +839,9 @@ "dev": true }, "node_modules/cookie": { - "version": "0.4.2", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.2.tgz", - "integrity": "sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==", + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", "dev": true, "engines": { "node": ">= 0.6" @@ -861,9 +861,9 @@ } }, "node_modules/cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dev": true, "dependencies": { "nice-try": "^1.0.4", @@ -1116,9 +1116,9 @@ } }, "node_modules/engine.io": { - "version": "6.5.5", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.5.5.tgz", - "integrity": "sha512-C5Pn8Wk+1vKBoHghJODM63yk8MvrO9EWZUfkAt5HAqIgPE4/8FF0PEGHXtEd40l223+cE5ABWuPzm38PHFXfMA==", + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.6.2.tgz", + "integrity": 
"sha512-gmNvsYi9C8iErnZdVcJnvCpSKbWTt1E8+JZo8b+daLninywUWi5NQ5STSHZ9rFjFO7imNcvb8Pc5pe/wMR5xEw==", "dev": true, "dependencies": { "@types/cookie": "^0.4.1", @@ -1126,7 +1126,7 @@ "@types/node": ">=10.0.0", "accepts": "~1.3.4", "base64id": "2.0.0", - "cookie": "~0.4.1", + "cookie": "~0.7.2", "cors": "~2.8.5", "debug": "~4.3.1", "engine.io-parser": "~5.2.1", @@ -3123,16 +3123,16 @@ } }, "node_modules/socket.io": { - "version": "4.7.5", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.7.5.tgz", - "integrity": "sha512-DmeAkF6cwM9jSfmp6Dr/5/mfMwb5Z5qRrSXLpo3Fq5SqyU8CMF15jIN4ZhfSwu35ksM1qmHZDQ/DK5XTccSTvA==", + "version": "4.8.0", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.8.0.tgz", + "integrity": "sha512-8U6BEgGjQOfGz3HHTYaC/L1GaxDCJ/KM0XTkJly0EhZ5U/du9uNEZy4ZgYzEzIqlx2CMm25CrCqr1ck899eLNA==", "dev": true, "dependencies": { "accepts": "~1.3.4", "base64id": "~2.0.0", "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.5.2", + "engine.io": "~6.6.0", "socket.io-adapter": "~2.5.2", "socket.io-parser": "~4.2.4" }, @@ -4296,9 +4296,9 @@ "dev": true }, "cookie": { - "version": "0.4.2", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.2.tgz", - "integrity": "sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==", + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", "dev": true }, "cors": { @@ -4312,9 +4312,9 @@ } }, "cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dev": true, "requires": { "nice-try": "^1.0.4", @@ -4504,9 +4504,9 @@ } }, "engine.io": { - "version": "6.5.5", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.5.5.tgz", - "integrity": "sha512-C5Pn8Wk+1vKBoHghJODM63yk8MvrO9EWZUfkAt5HAqIgPE4/8FF0PEGHXtEd40l223+cE5ABWuPzm38PHFXfMA==", + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.6.2.tgz", + "integrity": "sha512-gmNvsYi9C8iErnZdVcJnvCpSKbWTt1E8+JZo8b+daLninywUWi5NQ5STSHZ9rFjFO7imNcvb8Pc5pe/wMR5xEw==", "dev": true, "requires": { "@types/cookie": "^0.4.1", @@ -4514,7 +4514,7 @@ "@types/node": ">=10.0.0", "accepts": "~1.3.4", "base64id": "2.0.0", - "cookie": "~0.4.1", + "cookie": "~0.7.2", "cors": "~2.8.5", "debug": "~4.3.1", "engine.io-parser": "~5.2.1", @@ -6033,16 +6033,16 @@ "dev": true }, "socket.io": { - "version": "4.7.5", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.7.5.tgz", - "integrity": "sha512-DmeAkF6cwM9jSfmp6Dr/5/mfMwb5Z5qRrSXLpo3Fq5SqyU8CMF15jIN4ZhfSwu35ksM1qmHZDQ/DK5XTccSTvA==", + "version": "4.8.0", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.8.0.tgz", + "integrity": "sha512-8U6BEgGjQOfGz3HHTYaC/L1GaxDCJ/KM0XTkJly0EhZ5U/du9uNEZy4ZgYzEzIqlx2CMm25CrCqr1ck899eLNA==", "dev": true, "requires": { "accepts": "~1.3.4", "base64id": "~2.0.0", "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.5.2", + "engine.io": "~6.6.0", "socket.io-adapter": "~2.5.2", "socket.io-parser": "~4.2.4" } diff --git a/js/web/package.json b/js/web/package.json index 
d770499adada4..181d6127f5455 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -7,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.20.0", + "version": "1.21.0", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", @@ -78,25 +78,21 @@ "types": "./types.d.ts" }, "./all": { - "node": null, "import": "./dist/ort.all.bundle.min.mjs", "require": "./dist/ort.all.min.js", "types": "./types.d.ts" }, "./wasm": { - "node": null, - "import": "./dist/ort.wasm.min.mjs", + "import": "./dist/ort.wasm.bundle.min.mjs", "require": "./dist/ort.wasm.min.js", "types": "./types.d.ts" }, "./webgl": { - "node": null, "import": "./dist/ort.webgl.min.mjs", "require": "./dist/ort.webgl.min.js", "types": "./types.d.ts" }, "./webgpu": { - "node": null, "import": "./dist/ort.webgpu.bundle.min.mjs", "require": "./dist/ort.webgpu.min.js", "types": "./types.d.ts" diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 408f9e00a5cbd..529e9d1065e69 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -591,14 +591,14 @@ async function main() { // ort[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true' }, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, }); // ort.bundle.min.mjs await buildOrt({ isProduction: true, outputName: 'ort.bundle', format: 'esm', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true' }, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true' }, }); // ort.webgpu[.min].[m]js @@ -619,6 +619,13 @@ async function main() { outputName: 'ort.wasm', define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, }); + // ort.wasm.bundle.min.mjs + await buildOrt({ + isProduction: true, + outputName: 'ort.wasm.bundle', + format: 'esm', + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, + }); // ort.webgl[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort.webgl', diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index 5b8b0d27c88db..a07849a154e01 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -14,6 +14,7 @@ // import fs from 'fs'; +import { bootstrap as globalAgentBootstrap } from 'global-agent'; import https from 'https'; import jszip from 'jszip'; import path from 'path'; @@ -111,6 +112,11 @@ console.log( } ===`, ); +// Bootstrap global-agent to honor the proxy settings in +// environment variables, e.g. GLOBAL_AGENT_HTTPS_PROXY. +// See https://github.com/gajus/global-agent/blob/v3.0.0/README.md#environment-variables for details. +globalAgentBootstrap(); + const filter = buildId ? 
`&buildIds=${buildId}` : '&definitions=161' + diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 613b4507b2b15..8fbe9339feb9b 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -134,6 +134,56 @@ "type": "float32" } ] + }, + { + "name": "Expand in components = 1, out components = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [3, 2, 1], + "type": "float32" + }, + { + "data": [3, 1, 8], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [ + 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, + 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6 + ], + "dims": [3, 2, 8], + "type": "float32" + } + ] + }, + { + "name": "Expand in components = 4, out components = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + "data": [2, 1, 8], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [1, 2, 2, 8], + "type": "float32" + } + ] } ] }, diff --git a/js/web/test/data/ops/gather-nd.jsonc b/js/web/test/data/ops/gather-nd.jsonc new file mode 100644 index 0000000000000..209c7d1f74087 --- /dev/null +++ b/js/web/test/data/ops/gather-nd.jsonc @@ -0,0 +1,147 @@ +[ + { + "name": "GatherND int32", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100, 101, 102, 777, 778, 779, 1000, 1001, 1002], + "dims": [9], + "type": "int32" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100, 778, 1002], + "dims": [3], + "type": "int32" + } + ] + } + ] + }, + { + "name": "GatherND float32", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9], + "dims": [9], + "type": "float32" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100.0999984741211, 778.5, 1002.9000244140625], + "dims": [3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherND int32 [2 2 2], batch_dims", + "operator": "GatherND", + "attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 2, 2], + "type": "int32" + }, + { + "data": [1, 0], + "dims": [2, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [2, 3, 4, 5], + "dims": [2, 2], + "type": "int32" + } + ] + } + ] + }, + { + "name": "GatherND float16", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9], + "dims": [9], + "type": "float16" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100.0999984741211, 778.5, 1002.9000244140625], + "dims": [3], + "type": "float16" + } + ] + } + ] + }, + { + "name": "GatherND uint32 [2 2 2], batch_dims", + "operator": "GatherND", + "attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 2, 2], + "type": "uint32" + 
}, + { + "data": [1, 0], + "dims": [2, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [2, 3, 4, 5], + "dims": [2, 2], + "type": "uint32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/group-query-attention.jsonc b/js/web/test/data/ops/group-query-attention.jsonc index 2a4b265078456..f71e89f727cb1 100644 --- a/js/web/test/data/ops/group-query-attention.jsonc +++ b/js/web/test/data/ops/group-query-attention.jsonc @@ -1,6 +1,316 @@ [ { - "name": "GroupQueryAttention Basic", + "name": "GroupQueryAttention 0", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 1", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // past key, BS* + { + "data": [40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // past value, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // seqlens_k, unimplemented + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length, unimplemented + { + "data": [2], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [40, 41, 42, 43, 44, 45, 46, 47, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 2", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + 
"cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 72, 73, 74, 75, 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 3", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 4", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -12,44 +322,293 @@ "name": "T[0]", "inputs": [ { - "data": [ 
- 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 - ], - "dims": [1, 3, 16], + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 32], + "type": "float32" + }, + // key, BS* + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // value, BS* + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, + 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // past key, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // past value, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, + 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, 82, 83, 84, 85, + 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 32], + "type": "float32" + }, + { + // present key, BNSH + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 8, 9, 10, 11, 12, + 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 2, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 64, 65, 66, 67, 68, 69, 70, 71, 80, 81, 82, 83, 84, 85, 86, 87, 56, 57, + 58, 59, 60, 61, 62, 63, 72, 73, 74, 75, 76, 77, 78, 79, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 2, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 5", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], 
+ "outputs": [ + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 6", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 7", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], "type": "float32" }, // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], "dims": [1, 3, 8], "type": "float32" }, // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], "dims": [1, 3, 8], "type": "float32" }, // past key, BS* { - "data": null, + "data": [96, 97, 98, 99, 100, 101, 102, 103], + "dims": [1, 1, 1, 8], "type": "float32" }, // past value, BS* { - "data": null, + "data": [104, 105, 106, 107, 108, 109, 110, 111], + "dims": [1, 1, 1, 8], "type": "float32" }, // seqlens_k, unimplemented { - "data": [1], + "data": [3], "dims": [1], "type": "int32" }, // total_sequence_length, unimplemented { - "data": [1], + "data": [4], "dims": [1], "type": "int32" } @@ -57,22 +616,28 @@ 
"outputs": [ { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, - 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21 + 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, + 109, 110, 111 ], - "dims": [1, 3, 16], + "dims": [1, 3, 8], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present key, BNSH + "data": [ + 96, 97, 98, 99, 100, 101, 102, 103, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 66, 67, 68, 69, 70, 71 + ], + "dims": [1, 1, 4, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present value, BNSH + "data": [ + 104, 105, 106, 107, 108, 109, 110, 111, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, + 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 1, 4, 8], "type": "float32" } ] @@ -80,13 +645,12 @@ ] }, { - "name": "GroupQueryAttention Scale", + "name": " GroupQueryAttention 8", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" }, - { "name": "scale", "data": 2.0, "type": "float" } + { "name": "kv_num_heads", "data": 2, "type": "int" } ], "cases": [ { @@ -94,38 +658,43 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 4, 8], + "dims": [1, 1, 32], "type": "float32" }, + // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, + // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], + "dims": [1, 1, 16], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -135,35 +704,34 @@ "outputs": [ { "data": [ - 1.000006079673767, 1.000006079673767, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, - 1, 1, 1, 1.9820137023925781, 1.9820137023925781, 1.9999991655349731, 1.9999991655349731 + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 4, 8], + "dims": [1, 1, 32], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present key, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 1, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 
58, 59, 60, 61, 62, 63], + "dims": [1, 2, 1, 8], "type": "float32" } ] } ] }, - { - "name": "GroupQueryAttention, different sequence length", + "name": "GroupQueryAttention 9", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, + { "name": "num_heads", "data": 2, "type": "int" }, { "name": "kv_num_heads", "data": 2, "type": "int" } ], "cases": [ @@ -171,39 +739,41 @@ "name": "T[0]", "inputs": [ { - "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 - ], - "dims": [1, 4, 8], + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], "type": "float32" }, + // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], "type": "float32" }, + // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -212,23 +782,20 @@ ], "outputs": [ { - "data": [ - 1.014165997505188, 1.014165997505188, 1.0000015497207642, 1.0000015497207642, 1.99828040599823, - 1.99828040599823, 1.9998981952667236, 1.9998981952667236, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, - 1.9995813369750977, 1.9995813369750977, 1.9999752044677734, 1.9999752044677734, 1, 1, 1, 1, - 1.8044296503067017, 1.8044296503067017, 1.9929646253585815, 1.9929646253585815 - ], - "dims": [1, 4, 8], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 1, 8], "type": "float32" }, { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 1, 8], "type": "float32" } ] @@ -236,12 +803,164 @@ ] }, { - "name": "GroupQueryAttention Basic, q k v same head number", + "name": "GroupQueryAttention 10", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 4, "type": "int" } + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // 
pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 16], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 16], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 11", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 2, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [2], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [2], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 12", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -249,45 +968,49 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, + // key, BS* { "data": [ - 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, - 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, + // value, BS* { "data": [ - 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1, - 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 32], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 32], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -297,26 +1020,28 @@ "outputs": [ { "data": [ - 1, 12, 21, 131, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, - 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, - 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 3, 4, 4], + "dims": [1, 1, 1, 32], "type": "float32" }, { + // present value, BNSH "data": [ - 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1, - 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 4, 4], + "dims": [1, 1, 1, 32], "type": "float32" } ] @@ -324,12 +1049,12 @@ ] }, { - "name": "GroupQueryAttention, no past kv, used as reference", + "name": "GroupQueryAttention 13", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -337,50 +1062,51 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 7, 16], + "dims": [1, 4, 8], "type": "float32" }, + // key, BS* { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 7, 8], + "dims": [1, 4, 8], "type": "float32" }, + // value, BS* { "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 
- 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 8], + "dims": [1, 4, 8], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [4], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [4], "dims": [1], "type": "int32" } @@ -388,29 +1114,28 @@ "outputs": [ { "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 16], + "dims": [1, 4, 8], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 7, 2, 4], + "dims": [1, 1, 4, 8], "type": "float32" }, { + // present value, BNSH "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 2, 4], + "dims": [1, 1, 4, 8], "type": "float32" } ] @@ -418,12 +1143,12 @@ ] }, { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 1", + "name": "GroupQueryAttention PackedQKV 14", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -431,52 +1156,41 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 7, 16], + "dims": [1, 1, 
32], "type": "float32" }, - // new key, BS* + // key, BS* { - "data": [ - 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 6, 8], + "data": null, "type": "float32" }, - // new value, BS* + // value, BS* { - "data": [ - 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 6, 8], + "data": null, "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": [1, 2, 3, 4, 5, 6, 7, 8], - "dims": [1, 1, 2, 4], + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": [0, 1, 2, 3, 4, 5, 6, 7], - "dims": [1, 1, 2, 4], + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -485,38 +1199,121 @@ ], "outputs": [ { - "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 - ], - "dims": [1, 7, 16], + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], "type": "float32" }, { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 7, 2, 4], + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], "type": "float32" }, { - "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 7, 2, 4], + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], "type": "float32" } ] } ] }, + // TODO: Uncomment when a bug that is causing the test to fail occasionally, is fixed, or failure is understood. 
+ // { + // "name": "GroupQueryAttention PackedQKV 15", + // "operator": "GroupQueryAttention", + // "opset": { "domain": "com.microsoft", "version": 1 }, + // "attributes": [ + // { "name": "num_heads", "data": 4, "type": "int" }, + // { "name": "kv_num_heads", "data": 2, "type": "int" } + // ], + // "cases": [ + // { + // "name": "T[0]", + // "inputs": [ + // { + // "data": [ + // 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, + // 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, + // 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, + // 22, 21, 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, + // 1, 3, 4, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, + // 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, + // 2, 131, 22, 21 + // ], + // "dims": [1, 3, 64], + // "type": "float32" + // }, + // // key + // { + // "data": null, + // "type": "float32" + // }, + // // value + // { + // "data": null, + // "type": "float32" + // }, + // // pask key, BNSH + // { + // "data": [], + // "dims": [1, 2, 0, 8], + // "type": "float32" + // }, + // // pask value, BNSH + // { + // "data": [], + // "dims": [1, 2, 0, 8], + // "type": "float32" + // }, + // // seqlens_k + // { + // "data": [3], + // "dims": [1], + // "type": "int32" + // }, + // // total_sequence_length + // { + // "data": [3], + // "dims": [1], + // "type": "int32" + // } + // ], + // "outputs": [ + // { + // "data": [ + // 1, 9, 1, 1, 2, 2, 2, 2, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 1, 12, 21, 131, 22, 21, 2, + // 2, 8, 12, 233, 4, 5, 6, 7, 8, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4, + // 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4 + // ], + // "dims": [1, 3, 32], + // "type": "float32" + // }, + // { + // // present key, BNSH + // "data": [ + // 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 2, 3, 4, 5, 6, 7, 131, 22, 21, 2, 2, 131, 22, 21, 5, 6, 7, 8, 1, 1, 3, 4, + // 8, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 2, 2, 2 + // ], + // "dims": [1, 2, 3, 8], + // "type": "float32" + // }, + // { + // // present value, BNSH + // "data": [ + // 1, 9, 1, 1, 2, 2, 2, 2, 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, + // 5, 6, 7, 8, 1, 1, 3, 4, 131, 22, 21, 2, 2, 131, 22, 21 + // ], + // "dims": [1, 2, 3, 8], + // "type": "float32" + // } + // ] + // } + // ] + // }, { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 2", + "name": "GroupQueryAttention PackedQKV 16", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -529,54 +1326,50 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191 ], - "dims": [1, 7, 16], + "dims": [1, 3, 64], "type": "float32" }, - // new key, BS* + // key { - "data": [ - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, - 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 5, 8], + "data": null, "type": "float32" }, - // new value, BS* + // value { - "data": [ - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 5, 8], + "data": null, "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - "dims": [1, 2, 2, 4], + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - "dims": [1, 2, 2, 4], + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [3], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [3], "dims": [1], "type": "int32" } @@ -584,29 +1377,33 @@ "outputs": [ { "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63, 112, 113, 114, 115, 116, 117, 118, 119, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 120, 121, 122, 123, 124, 125, 126, 127, 176, 177, 178, 179, 180, + 181, 182, 183, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 184, 185, + 186, 187, 188, 189, 190, 191 ], - "dims": [1, 7, 16], + "dims": [1, 3, 32], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 32, 33, 34, 35, 36, 37, 38, 39, 96, 97, 98, 99, 100, 101, 102, 103, 160, 161, 162, 163, 164, 165, 166, + 167, 40, 41, 42, 43, 44, 45, 46, 47, 104, 105, 106, 107, 108, 109, 110, 111, 168, 169, 170, 171, 172, 173, + 174, 175 ], - "dims": [1, 7, 2, 4], + "dims": [1, 2, 3, 8], "type": "float32" 
}, { + // present value, BNSH "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 48, 49, 50, 51, 52, 53, 54, 55, 112, 113, 114, 115, 116, 117, 118, 119, 176, 177, 178, 179, 180, 181, 182, + 183, 56, 57, 58, 59, 60, 61, 62, 63, 120, 121, 122, 123, 124, 125, 126, 127, 184, 185, 186, 187, 188, 189, + 190, 191 ], - "dims": [1, 7, 2, 4], + "dims": [1, 2, 3, 8], "type": "float32" } ] diff --git a/js/web/test/data/ops/matmul.jsonc b/js/web/test/data/ops/matmul.jsonc index 2c2cf509d7e3e..f5996db1aecb6 100644 --- a/js/web/test/data/ops/matmul.jsonc +++ b/js/web/test/data/ops/matmul.jsonc @@ -95,6 +95,56 @@ } ] }, + { + "name": "multiplies 3D tensors with M = 1", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [6, 1, 8], + "type": "float32" + }, + { + "data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [1, 8, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550], + "dims": [6, 1, 3], + "type": "float32" + } + ] + }, + { + "name": "multiplies 4D tensors with M = 1", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [2, 3, 1, 8], + "type": "float32" + }, + { + "data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [1, 1, 8, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550], + "dims": [2, 3, 1, 3], + "type": "float32" + } + ] + }, { "name": "multiplies 4D tensors", "inputs": [ @@ -313,6 +363,100 @@ "type": "float32" } ] + }, + { + "name": "same ranks different broadcast small 0", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 2, 2, 2], + "type": "float32" + }, + { + "data": [8, 9, 10, 11], + "dims": [2, 1, 2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 43, 77, 111, 11, 53, 95, 137], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "same ranks different broadcast small 1", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 1, 2, 2], + "type": "float32" + }, + { + "data": [8, 9, 10, 11], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 43, 11, 53, 77, 111, 95, 137], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "same ranks different broadcast larger 0", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ], + "dims": [1, 2, 2, 8], + "type": "float32" + }, + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [2, 1, 8, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1036, 3308, 5580, 7852, 1260, 4044, 6828, 9612], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "same ranks different broadcast larger 1", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ], + "dims": [2, 1, 2, 8], + "type": "float32" + }, + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 8, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1036, 3308, 1260, 4044, 5580, 7852, 6828, 9612], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] } ] } diff --git a/js/web/test/data/ops/scatternd.jsonc b/js/web/test/data/ops/scatternd.jsonc new file mode 100644 index 0000000000000..5135bb9e4d3a5 --- /dev/null +++ b/js/web/test/data/ops/scatternd.jsonc @@ -0,0 +1,472 @@ +[ + { + "name": "ScatterND int32", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 13 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9, 10, 11, 12], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 11, 3, 10, 9, 6, 7, 12], + "dims": [8], + "type": "int32" + } + ] + }, + { + "name": "int32", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [4, 4, 4], + "type": "int32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "int64" + }, + { + "data": [ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131 + ], + "dims": [2, 4, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [4, 4, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND float32", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 13 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.1, 11.3, 3.1, 10.2, 9.1, 6.1, 7.8, 12.5], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND add int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "add", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9, 10, 11, 12], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 13, 3, 14, 14, 6, 7, 20], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND add float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "add", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + 
"data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 13.5, 3.0999999046325684, 14.699999809265137, 14.40000057220459, 6.099999904632568, + 7.800000190734863, 21.399999618530273 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND mul int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 2486, 31, 4590, 4823, 61, 78, 11125], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND min int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "min", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND max int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "max", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 113, 31, 102, 91, 61, 78, 125], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND mul float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 24.860000610351562, 3.0999999046325684, 45.89999771118164, 48.230003356933594, + 6.099999904632568, 7.800000190734863, 111.24999237060547 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND min float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "min", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 
2.200000047683716, 3.0999999046325684, 4.5, 5.300000190734863, 6.099999904632568, + 7.800000190734863, 8.899999618530273 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND max float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "max", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 11.300000190734863, 3.0999999046325684, 10.199999809265137, 9.100000381469727, + 6.099999904632568, 7.800000190734863, 12.5 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND float16", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 11 }, + "cases": [ + { + "name": "float16", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float16" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float16" + } + ], + "outputs": [ + { + "data": [1.1, 11.3, 3.1, 10.2, 9.1, 6.1, 7.8, 12.5], + "dims": [8], + "type": "float16" + } + ] + } + ] + }, + { + "name": "ScatterND mul uint32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "uint32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "uint32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [11, 2486, 31, 4590, 4823, 61, 78, 11125], + "dims": [8], + "type": "uint32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index a7265d6444118..d431ceb1712a5 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -263,6 +263,30 @@ } ] }, + { + "name": "Transpose as reshape - perms:[1, 0, 2, 4, 3]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [1, 0, 2, 4, 3], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [1, 3, 2, 4, 1], + "type": "float32" + } + ] + } + ] + }, { "name": "Transpose - perms:[1, 0]", "operator": "Transpose", diff --git a/js/web/test/e2e/browser-test-wasm-binary-override.js b/js/web/test/e2e/browser-test-wasm-binary-override.js index 471c26f6990b5..27cce2ca06236 100644 --- a/js/web/test/e2e/browser-test-wasm-binary-override.js +++ b/js/web/test/e2e/browser-test-wasm-binary-override.js @@ -7,7 +7,7 @@ const documentUrl = document.currentScript.src; it('Browser E2E testing - WebAssembly backend', async function () { // preload .wasm file binary - const wasmUrl = new URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.wasm', documentUrl).href; + const wasmUrl = new 
URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.jsep.wasm', documentUrl).href; const response = await fetch(wasmUrl); // make sure the .wasm file is loaded successfully diff --git a/js/web/test/e2e/browser-test-wasm-path-override-filename-jsep.js b/js/web/test/e2e/browser-test-wasm-path-override-filename-jsep.js new file mode 100644 index 0000000000000..d325a5ca7187d --- /dev/null +++ b/js/web/test/e2e/browser-test-wasm-path-override-filename-jsep.js @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +it('Browser E2E testing - WebAssembly backend (path override filename)', async function () { + // check base URL port from test args + if (typeof __ort_arg_port === 'undefined') { + throw new Error('test flag --port= is required'); + } + const base = `http://localhost:${__ort_arg_port}/`; + + ort.env.wasm.wasmPaths = {}; + + if (typeof __ort_arg_files === 'string' && __ort_arg_files.includes('wasm')) { + const overrideWasmUrl = new URL('./test-wasm-path-override/jsep-renamed.wasm', base).href; + console.log(`ort.env.wasm.wasmPaths['wasm'] = ${JSON.stringify(overrideWasmUrl)};`); + ort.env.wasm.wasmPaths.wasm = overrideWasmUrl; + } + + if (typeof __ort_arg_files === 'string' && __ort_arg_files.includes('mjs')) { + const overrideMjsUrl = new URL('./test-wasm-path-override/jsep-renamed.mjs', base).href; + console.log(`ort.env.wasm.wasmPaths['mjs'] = ${JSON.stringify(overrideMjsUrl)};`); + ort.env.wasm.wasmPaths.mjs = overrideMjsUrl; + } + + await testFunction(ort, { executionProviders: ['wasm'] }); +}); diff --git a/js/web/test/e2e/run-data.js b/js/web/test/e2e/run-data.js index 04079b042bc23..dbc3ca0bd2460 100644 --- a/js/web/test/e2e/run-data.js +++ b/js/web/test/e2e/run-data.js @@ -14,7 +14,7 @@ const NODEJS_TEST_CASES = [ // [test_for_same_origin, test_for_cross_origin, main_js, ort_main_js, [test_args]] const BROWSER_TEST_CASES = [ // IIFE - [true, true, './browser-test-webgl.js', 'ort.min.js'], // webgl + [true, true, './browser-test-webgl.js', 'ort.all.min.js'], // webgl [true, true, './browser-test-webgl.js', 'ort.webgl.min.js'], // webgl [true, true, './browser-test-wasm.js', 'ort.wasm.min.js'], // wasm, ort.wasm [true, true, './browser-test-wasm-multi-session-create.js', 'ort.min.js'], // wasm, multi-session create @@ -24,7 +24,7 @@ const BROWSER_TEST_CASES = [ [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=1', 'proxy=1']], // wasm, 1 thread, proxy // ort.min.mjs - [true, true, './browser-test-webgl.js', 'ort.min.mjs'], // webgl + [true, true, './browser-test-webgl.js', 'ort.webgl.min.mjs'], // webgl [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=1']], // wasm, 1 thread [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2']], // wasm, 2 threads [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2', 'proxy=1']], // wasm, 2 threads, proxy @@ -41,22 +41,22 @@ const BROWSER_TEST_CASES = [ // path override: // wasm, path override filenames for both mjs and wasm, same origin - [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=mjs,wasm']], + [true, false, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=9876', 'files=mjs,wasm']], [true, false, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=9876', 'files=mjs,wasm']], // wasm, path override filenames for both mjs and wasm, cross origin - [false, true, 
'./browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=8081', 'files=mjs,wasm']], + [false, true, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=8081', 'files=mjs,wasm']], [false, true, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=8081', 'files=mjs,wasm']], // wasm, path override filename for wasm, same origin - [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=wasm']], + [true, false, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=9876', 'files=wasm']], [true, false, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=9876', 'files=wasm']], // wasm, path override filename for wasm, cross origin - [false, true, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=8081', 'files=wasm']], + [false, true, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=8081', 'files=wasm']], [false, true, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=8081', 'files=wasm']], // wasm, path override filename for mjs, same origin - [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=mjs']], + [true, false, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=9876', 'files=mjs']], [true, false, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=9876', 'files=mjs']], // wasm, path override filename for mjs, cross origin - [false, true, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=8081', 'files=mjs']], + [false, true, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=8081', 'files=mjs']], [false, true, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=8081', 'files=mjs']], // wasm, path override prefix, same origin [true, false, './browser-test-wasm-path-override-prefix.js', 'ort.min.js', ['port=9876']], diff --git a/js/web/test/e2e/run.js b/js/web/test/e2e/run.js index 93f9d4a144bf2..3361bbece64ed 100644 --- a/js/web/test/e2e/run.js +++ b/js/web/test/e2e/run.js @@ -146,6 +146,10 @@ function prepareWasmPathOverrideFiles() { fs.copyFileSync(`${sourceFile}.wasm`, path.join(folder, 'ort-wasm-simd-threaded.wasm')); fs.copyFileSync(`${sourceFile}.mjs`, path.join(folder, 'renamed.mjs')); fs.copyFileSync(`${sourceFile}.wasm`, path.join(folder, 'renamed.wasm')); + fs.copyFileSync(`${sourceFile}.jsep.mjs`, path.join(folder, 'ort-wasm-simd-threaded.jsep.mjs')); + fs.copyFileSync(`${sourceFile}.jsep.wasm`, path.join(folder, 'ort-wasm-simd-threaded.jsep.wasm')); + fs.copyFileSync(`${sourceFile}.jsep.mjs`, path.join(folder, 'jsep-renamed.mjs')); + fs.copyFileSync(`${sourceFile}.jsep.wasm`, path.join(folder, 'jsep-renamed.wasm')); } async function testAllNodejsCases() { diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 95fe60b2c79be..1c4763d0f22d8 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -570,14 +570,14 @@ "test_greater_equal_expanded", "test_greater_equal", "test_greater", - // // "test_gridsample_aligncorners_true", - // // "test_gridsample_bicubic", - // // "test_gridsample_bilinear", - // // "test_gridsample_border_padding", - // // "test_gridsample_nearest", - // // "test_gridsample_reflection_padding", - // // "test_gridsample_zeros_padding", - // // "test_gridsample", + "test_gridsample_aligncorners_true", + "test_gridsample_bicubic", 
+ "test_gridsample_bilinear", + "test_gridsample_border_padding", + "test_gridsample_nearest", + "test_gridsample_reflection_padding", + "test_gridsample_zeros_padding", + "test_gridsample", // // "test_gru_batchwise", // // "test_gru_defaults", // // "test_gru_seq_length", @@ -1365,6 +1365,7 @@ "gather.jsonc", "gather-block-quantized.jsonc", "gather-elements.jsonc", + "gather-nd.jsonc", "gemm.jsonc", "global-average-pool.jsonc", "greater.jsonc", @@ -1396,6 +1397,7 @@ "pow-big-number.jsonc", "reshape.jsonc", "rotary-embedding.jsonc", + "scatternd.jsonc", "simplified-layer-norm.jsonc", "skip-layer-norm.jsonc", "skip-simplified-layer-norm.jsonc", @@ -1532,14 +1534,14 @@ "test_add_bcast", // "test_add_uint8", "test_add", - // "test_and_bcast3v1d", - // "test_and_bcast3v2d", - // "test_and_bcast4v2d", - // "test_and_bcast4v3d", - // "test_and_bcast4v4d", - // "test_and2d", - // "test_and3d", - // "test_and4d", + "test_and_bcast3v1d", + "test_and_bcast3v2d", + "test_and_bcast4v2d", + "test_and_bcast4v3d", + "test_and_bcast4v4d", + "test_and2d", + "test_and3d", + "test_and4d", "test_argmax_default_axis_example_select_last_index", "test_argmax_default_axis_example", "test_argmax_default_axis_random_select_last_index", @@ -1699,13 +1701,13 @@ "test_cos", // "test_cosh_example", // "test_cosh", - // "test_cumsum_1d_exclusive", - // "test_cumsum_1d_reverse_exclusive", - // "test_cumsum_1d_reverse", - // "test_cumsum_1d", - // "test_cumsum_2d_axis_0", - // "test_cumsum_2d_axis_1", - // "test_cumsum_2d_negative_axis", + "test_cumsum_1d_exclusive", + "test_cumsum_1d_reverse_exclusive", + "test_cumsum_1d_reverse", + "test_cumsum_1d", + "test_cumsum_2d_axis_0", + "test_cumsum_2d_axis_1", + "test_cumsum_2d_negative_axis", // "test_depthtospace_crd_mode_example", // "test_depthtospace_crd_mode", // "test_depthtospace_dcr_mode", @@ -1777,9 +1779,9 @@ "test_gather_elements_0", "test_gather_elements_1", "test_gather_elements_negative_indices", - // "test_gathernd_example_float32", - // "test_gathernd_example_int32_batch_dim1", - // "test_gathernd_example_int32", + "test_gathernd_example_float32", + "test_gathernd_example_int32_batch_dim1", + "test_gathernd_example_int32", "test_gemm_all_attributes", "test_gemm_alpha", "test_gemm_beta", @@ -2089,14 +2091,14 @@ // // "test_optional_get_element", // // "test_optional_has_element_empty", // // "test_optional_has_element", - // "test_or_bcast3v1d", - // "test_or_bcast3v2d", - // "test_or_bcast4v2d", - // "test_or_bcast4v3d", - // "test_or_bcast4v4d", - // "test_or2d", - // "test_or3d", - // "test_or4d", + "test_or_bcast3v1d", + "test_or_bcast3v2d", + "test_or_bcast4v2d", + "test_or_bcast4v3d", + "test_or_bcast4v4d", + "test_or2d", + "test_or3d", + "test_or4d", "test_pow_bcast_array", "test_pow_bcast_scalar", "test_pow_example", @@ -2254,15 +2256,15 @@ // // "test_round", // // "test_scan_sum", // // "test_scan9_sum", - // // "test_scatter_elements_with_axis", - // // "test_scatter_elements_with_duplicate_indices", - // // "test_scatter_elements_with_negative_indices", - // // "test_scatter_elements_without_axis", + "test_scatter_elements_with_axis", + "test_scatter_elements_with_duplicate_indices", + "test_scatter_elements_with_negative_indices", + "test_scatter_elements_without_axis", // // "test_scatter_with_axis", // // "test_scatter_without_axis", - // // "test_scatternd_add", - // // "test_scatternd_multiply", - // // "test_scatternd", + "test_scatternd_add", + "test_scatternd_multiply", + "test_scatternd", // // "test_sce_mean_3d_expanded", // // 
"test_sce_mean_3d_log_prob_expanded", // // "test_sce_mean_3d_log_prob", @@ -2352,7 +2354,7 @@ // "test_shrink_soft", "test_sigmoid_example", "test_sigmoid", - // "test_sign", + "test_sign", // "test_simple_rnn_batchwise", // "test_simple_rnn_defaults", // "test_simple_rnn_with_initial_bias", @@ -2362,14 +2364,14 @@ // "test_sinh", // // "test_size_example", // // "test_size", - // "test_slice_default_axes", - // "test_slice_default_steps", - // "test_slice_end_out_of_bounds", - // "test_slice_neg_steps", - // "test_slice_neg", - // "test_slice_negative_axes", - // "test_slice_start_out_of_bounds", - // "test_slice", + "test_slice_default_axes", + "test_slice_default_steps", + "test_slice_end_out_of_bounds", + "test_slice_neg_steps", + "test_slice_neg", + "test_slice_negative_axes", + "test_slice_start_out_of_bounds", + "test_slice", // "test_softmax_axis_0_expanded", "test_softmax_axis_0", // "test_softmax_axis_1_expanded", @@ -2550,16 +2552,16 @@ "test_unsqueeze", // "test_wrap_pad" // "test_upsample_nearest", - "test_where_example" + "test_where_example", // "test_where_long_example", - // "test_xor_bcast3v1d", - // "test_xor_bcast3v2d", - // "test_xor_bcast4v2d", - // "test_xor_bcast4v3d", - // "test_xor_bcast4v4d", - // "test_xor2d", - // "test_xor3d", - // "test_xor4d" + "test_xor_bcast3v1d", + "test_xor_bcast3v2d", + "test_xor_bcast4v2d", + "test_xor_bcast4v3d", + "test_xor_bcast4v4d", + "test_xor2d", + "test_xor3d", + "test_xor4d" ], "ops": [] } diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index a8945222b485a..5de39535a5c07 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -291,14 +291,9 @@ export class ModelTestContext { if (['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) { const webnnOptions = executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption; const deviceType = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.deviceType; - const numThreads = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.numThreads; const powerPreference = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.powerPreference; - mlContext = await navigator.ml.createContext({ - deviceType, - numThreads, - powerPreference, - }); + mlContext = await navigator.ml.createContext({ deviceType, powerPreference }); (executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption).context = mlContext; if (!deviceType) { (executionProviderConfig as ort.InferenceSession.WebNNContextOptions).deviceType = deviceType; @@ -591,11 +586,11 @@ export class TensorResultValidator { } } -function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { +async function createGpuTensorForInput(cpuTensor: ort.Tensor): Promise { if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); } - const device = ort.env.webgpu.device as GPUDevice; + const device = await ort.env.webgpu.device; const gpuBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, @@ -617,14 +612,14 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { }); } -function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { +async function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { if (!isGpuBufferSupportedType(type)) { throw new Error(`createGpuTensorForOutput can 
not work with ${type} tensor`); } const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!; - const device = ort.env.webgpu.device as GPUDevice; + const device = await ort.env.webgpu.device; const gpuBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, @@ -666,7 +661,8 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty shape: dims as number[], // Assign both shape and dimensions while transitioning to new API. dimensions: dims as number[], - usage: MLTensorUsage.READ, + usage: typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ, + readable: true, }); return ort.Tensor.fromMLTensor(mlTensor, { @@ -690,7 +686,8 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso shape: cpuTensor.dims as number[], // Assign both shape and dimensions while transitioning to new API. dimensions: cpuTensor.dims as number[], - usage: MLTensorUsage.WRITE, + usage: typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.WRITE, + writable: true, }); mlContext.writeTensor(mlTensor, cpuTensor.data); return ort.Tensor.fromMLTensor(mlTensor, { @@ -728,7 +725,7 @@ export async function sessionRun(options: { if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') { feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]); } else { - feeds[name] = createGpuTensorForInput(feeds[name]); + feeds[name] = await createGpuTensorForInput(feeds[name]); } } } @@ -745,7 +742,7 @@ export async function sessionRun(options: { if (options.ioBinding === 'ml-tensor') { fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims); } else { - fetches[name] = createGpuTensorForOutput(type, dims); + fetches[name] = await createGpuTensorForOutput(type, dims); } } } diff --git a/objectivec/error_utils.mm b/objectivec/error_utils.mm index 335cf8894d549..e8d4d5bb365c9 100644 --- a/objectivec/error_utils.mm +++ b/objectivec/error_utils.mm @@ -11,7 +11,7 @@ void ORTSaveCodeAndDescriptionToError(int code, const char* descriptionCstr, NSE if (!error) return; NSString* description = [NSString stringWithCString:descriptionCstr - encoding:NSASCIIStringEncoding]; + encoding:NSUTF8StringEncoding]; *error = [NSError errorWithDomain:kOrtErrorDomain code:code diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h index 6ff18176ebeb2..41d15aa39453a 100644 --- a/objectivec/include/ort_coreml_execution_provider.h +++ b/objectivec/include/ort_coreml_execution_provider.h @@ -29,7 +29,10 @@ NS_ASSUME_NONNULL_BEGIN * Whether the CoreML execution provider should run on CPU only. */ @property BOOL useCPUOnly; - +/** + * exclude ANE in CoreML. + */ +@property BOOL useCPUAndGPU; /** * Whether the CoreML execution provider is enabled on subgraphs. */ @@ -67,7 +70,22 @@ NS_ASSUME_NONNULL_BEGIN */ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options error:(NSError**)error; - +/** + * Enables the CoreML execution provider in the session configuration options. + * It is appended to the execution provider list which is ordered by + * decreasing priority. + * + * @param provider_options The CoreML execution provider options in dict. 
+ * available keys-values: more detail in core/providers/coreml/coreml_execution_provider.h + * kCoremlProviderOption_MLComputeUnits: one of "CPUAndNeuralEngine", "CPUAndGPU", "CPUOnly", "All" + * kCoremlProviderOption_ModelFormat: one of "MLProgram", "NeuralNetwork" + * kCoremlProviderOption_RequireStaticInputShapes: "1" or "0" + * kCoremlProviderOption_EnableOnSubgraphs: "1" or "0" + * @param error Optional error information set if an error occurs. + * @return Whether the provider was enabled successfully. + */ +- (BOOL)appendCoreMLExecutionProviderWithOptionsV2:(NSDictionary*)provider_options + error:(NSError**)error; @end NS_ASSUME_NONNULL_END diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm index 58b47d68eea63..0c790a91fb8b9 100644 --- a/objectivec/ort_coreml_execution_provider.mm +++ b/objectivec/ort_coreml_execution_provider.mm @@ -25,6 +25,7 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti try { const uint32_t flags = (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) | + (options.useCPUAndGPU ? COREML_FLAG_USE_CPU_AND_GPU : 0) | (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) | (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) | (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) | @@ -42,6 +43,21 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti #endif } +- (BOOL)appendCoreMLExecutionProviderWithOptionsV2:(NSDictionary*)provider_options + error:(NSError**)error { +#if ORT_OBJC_API_COREML_EP_AVAILABLE + try { + return [self appendExecutionProvider:@"CoreML" providerOptions:provider_options error:error]; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error); + +#else // !ORT_OBJC_API_COREML_EP_AVAILABLE + static_cast(provider_options); + ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error); + return NO; +#endif +} + @end NS_ASSUME_NONNULL_END diff --git a/objectivec/test/ort_session_test.mm b/objectivec/test/ort_session_test.mm index 508289f7bc748..409ee7e1584e2 100644 --- a/objectivec/test/ort_session_test.mm +++ b/objectivec/test/ort_session_test.mm @@ -223,6 +223,28 @@ - (void)testAppendCoreMLEP { ORTAssertNullableResultSuccessful(session, err); } +- (void)testAppendCoreMLEP_v2 { + NSError* err = nil; + ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions]; + NSDictionary* provider_options = @{@"EnableOnSubgraphs" : @"1"}; // set an arbitrary option + + BOOL appendResult = [sessionOptions appendCoreMLExecutionProviderWithOptionsV2:provider_options + error:&err]; + + if (!ORTIsCoreMLExecutionProviderAvailable()) { + ORTAssertBoolResultUnsuccessful(appendResult, err); + return; + } + + ORTAssertBoolResultSuccessful(appendResult, err); + + ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv + modelPath:[ORTSessionTest getAddModelPath] + sessionOptions:sessionOptions + error:&err]; + ORTAssertNullableResultSuccessful(session, err); +} + - (void)testAppendXnnpackEP { NSError* err = nil; ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions]; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 0e9a924bde4bb..9d533af616288 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. 
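A brief aside on the dictionary-based CoreML API introduced above: the Objective-C V2 method forwards to the generic named-provider registration (appendExecutionProvider:providerOptions:error:), so the same string-keyed options can also be exercised from the C++ API. The sketch below is illustrative only; the key/value strings are taken from the doc comment above, and the MakeCoreMLSessionOptions helper is hypothetical.

```cpp
// Minimal sketch (not part of this change): passing the CoreML provider
// options shown above through the C++ API's named-provider overload,
// which is the same path the Objective-C V2 wrapper ultimately uses.
#include <onnxruntime_cxx_api.h>

#include <string>
#include <unordered_map>

Ort::SessionOptions MakeCoreMLSessionOptions() {
  Ort::SessionOptions session_options;
  std::unordered_map<std::string, std::string> coreml_options = {
      {"MLComputeUnits", "CPUAndGPU"},     // exclude the ANE
      {"ModelFormat", "MLProgram"},        // or "NeuralNetwork"
      {"RequireStaticInputShapes", "0"},
      {"EnableOnSubgraphs", "0"},
  };
  // Throws if this build does not include the CoreML execution provider.
  session_options.AppendExecutionProvider("CoreML", coreml_options);
  return session_options;
}
```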
""" -__version__ = "1.20.0" +__version__ = "1.21.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index e0fa581c8071d..97d6cc1ce7d66 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -48,6 +48,7 @@ enum AttentionKernelType { AttentionKernel_CutlassMemoryEfficientAttention, AttentionKernel_FlashAttention, AttentionKernel_CudnnFlashAttention, + AttentionKernel_LeanAttention, AttentionKernel_Default }; @@ -65,7 +66,6 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; - int num_splits; int rotary_embedding; bool is_unidirectional; bool past_present_share_buffer; @@ -79,6 +79,45 @@ struct AttentionParameters { AttentionQkvFormat qkv_format; }; +struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { + int beam_width = 1; + + // Only NeoX style rotary embedding is supported + int rotary_embedding_dim = 0; + int t_step = 0; + + // Whether to use multihead attention(excludes matmul and bias) + bool is_mha = false; + bool is_cross_attention = false; + bool is_packed_qkv = false; + + // Useful to better use global memory bandwidth on certain CUDA architectures. + // Turned off by default for now until we fully understand performance implications + // for all types of workloads. + // Can be turned on by appropriate environment variable (see attention_common.h). + bool kv_data_in_flight = false; + + void* q = nullptr; + void* q_bias = nullptr; + + void* k = nullptr; + void* k_bias = nullptr; + + void* v = nullptr; + void* v_bias = nullptr; + + void* attention_bias = nullptr; + + void* k_cache = nullptr; + void* v_cache = nullptr; + + void* out = nullptr; + void* out_qk = nullptr; + + const int32_t* cache_indir = nullptr; + const int32_t* mask = nullptr; // [B, total_sequence_length] +}; + // Parameters deduced from node attributes and inputs/outputs. struct PackedAttentionParameters { int batch_size; @@ -169,10 +208,13 @@ enum class AttentionBackend : int { CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention. MATH = 16, // unfused kernel cannot be disabled right now. - // The following kernels might be deprecated in the future. + // The following TRT kernels might be deprecated in the future. TRT_FLASH_ATTENTION = 32, TRT_CROSS_ATTENTION = 64, TRT_CAUSAL_ATTENTION = 128, + + // Experimental kernels + LEAN_ATTENTION = 256, }; // Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled). @@ -200,6 +242,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF // Environment variable to enable or disable flash attention. Default is 0 (enabled). constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; +// Environment variable to enable or disable lean attention. Default is 0 (disabled). 
+constexpr const char* kEnableLeanAttention = "ORT_ENABLE_LEAN_ATTENTION"; + // Minimum sequence length to perfer memory efficient attention when data type is float32 constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32"; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index ae2eaf0204026..87938f3728750 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -36,18 +36,22 @@ class AttentionCPUBase : public AttentionBase { int v_head_size, // head size of V (H_v) int v_hidden_size, // hidden size of V (D_v) const Tensor* attn_bias, // additive bias applied on scaled QK. - OpKernelContext* context) const { + OpKernelContext* context, + Tensor* output_qk = nullptr, // output buffer for QK (if needed) + int past_sequence_length = 0, // sequence length of past state + bool past_present_share_buffer = false) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - int past_sequence_length = 0; Tensor* present = nullptr; - if (present_key == nullptr && present_value == nullptr) { - present = GetPresent(context, past, batch_size, v_head_size, kv_sequence_length, past_sequence_length); - } else if (past_key != nullptr && past_value != nullptr) { - past_sequence_length = static_cast(past_key->Shape().GetDims()[2]); + if (past_sequence_length == 0) { + if (present_key == nullptr && present_value == nullptr) { + present = GetPresent(context, past, batch_size, v_head_size, kv_sequence_length, past_sequence_length); + } else if (past_key != nullptr && past_value != nullptr) { + past_sequence_length = static_cast(past_key->Shape().GetDims()[2]); + } } // Total sequence length including that of past state: T = P + L @@ -71,9 +75,9 @@ class AttentionCPUBase : public AttentionBase { if (mask_data != nullptr) { // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f). - // Merge padding mask with causual mask, and broadcast to 3D (BxSxT). + // Merge padding mask with causal mask, and broadcast to 3D (BxSxT). PrepareMask(mask_index_data, mask_index_dims, static_cast(mask_data), - causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); + causal, batch_size, sequence_length, kv_sequence_length, past_sequence_length, mask_filter_value_); DUMP_CPU_TENSOR("Mask3D", static_cast(mask_data), batch_size, sequence_length, total_sequence_length); } @@ -85,10 +89,18 @@ class AttentionCPUBase : public AttentionBase { T* present_key_data = present_key != nullptr ? present_key->MutableData() : nullptr; const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + T* output_qk_data = output_qk != nullptr ? output_qk->MutableData() : nullptr; const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data() : nullptr; auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span{}; + // Used for DecoderMaskedMultiHeadAttention + int max_sequence_length = 0; + if (past_present_share_buffer) { + ORT_ENFORCE(past_key != nullptr && past_value != nullptr); + max_sequence_length = static_cast(past_key->Shape().GetDims()[2]); + } + // Compute the attention score. 
size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T); auto attention_probs = allocator->Alloc(bytes); @@ -96,8 +108,9 @@ class AttentionCPUBase : public AttentionBase { ComputeAttentionProbs(static_cast(attention_probs), Q, K, static_cast(mask_data), batch_size, sequence_length, kv_sequence_length, past_sequence_length, - qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, - present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims); + qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, present_data, + present_key_data, tp, scale, attn_bias_data, attn_bias_dims, output_qk_data, + past_present_share_buffer, max_sequence_length); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -106,7 +119,8 @@ class AttentionCPUBase : public AttentionBase { ComputeVxAttentionScore(output->MutableData(), static_cast(out_tmp_data), static_cast(attention_probs), V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size, - v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp); + v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp, + past_present_share_buffer, max_sequence_length); return Status::OK(); } @@ -117,29 +131,32 @@ class AttentionCPUBase : public AttentionBase { // 1 x mask_data(B, N, S, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - T* mask_data, // buffer for mask data. - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int kv_sequence_length, // sequence length of cross-attention (L) - int past_sequence_length, // sequence length of past state - int head_size, // head size of self-attention - const T* past, // past state - const T* past_key, // past key only (if not using past state) - T* present, // present state - T* present_key, // present key only (if not using present state) - ThreadPool* tp, // thread pool - float scale, // scale factor - const T* attn_bias_data, // attention bias - gsl::span attn_bias_dims // attention bias shape - ) const { + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + T* mask_data, // buffer for mask data. 
+ int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention (S) + int kv_sequence_length, // sequence length of cross-attention (L) + int past_sequence_length, // sequence length of past state + int head_size, // head size of self-attention + const T* past, // past state + const T* past_key, // past key only (if not using past state) + T* present, // present state + T* present_key, // present key only (if not using present state) + ThreadPool* tp, // thread pool + float scale, // scale factor + const T* attn_bias_data, // attention bias + gsl::span attn_bias_dims, // attention bias shape + T* output_qk_data = nullptr, // scaled output QK buffer + bool past_present_share_buffer = false, + int max_sequence_length = 0) const { const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H const size_t kv_input_chunk_length = static_cast(kv_sequence_length) * head_size; // L x H const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H + const size_t cache_chunk_length = static_cast(max_sequence_length) * head_size; // M x H DUMP_CPU_TENSOR_INIT(); DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size); @@ -164,7 +181,8 @@ class AttentionCPUBase : public AttentionBase { } if (present || present_key) { - double bytes_to_copy_key = static_cast(sizeof(T) * present_chunk_length); + double bytes_to_copy_key = (past_present_share_buffer ? kv_input_chunk_length : present_chunk_length) * + static_cast(sizeof(T)); unit_cost.bytes_loaded += bytes_to_copy_key; unit_cost.bytes_stored += bytes_to_copy_key; } @@ -214,7 +232,12 @@ class AttentionCPUBase : public AttentionBase { // Concatenate past_K and K : (BxNx)PxH, (BxNx)LxH -> (BxNx)TxH k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i); } else if (nullptr != present_key) { - k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, i); + if (past_present_share_buffer) { + k = present_key + cache_chunk_length * i; + memcpy(const_cast(k) + past_chunk_length, K + head_size * i, head_size * sizeof(T)); + } else { + k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, i); + } } // Compute Q*K' + AttentionMask @@ -230,6 +253,12 @@ class AttentionCPUBase : public AttentionBase { }); } + if (output_qk_data != nullptr) { + // Output the scaled Q*K^T if needed. 
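+      // Note: this snapshot is taken after the attention bias / mask have been added
+      // but before the softmax below, so it holds the raw scaled logits.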
+ memcpy(output_qk_data, attention_probs, + SafeInt(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T)); + } + DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length); // attention_probs(B, N, S, T) = Softmax(attention_probs) @@ -257,12 +286,15 @@ class AttentionCPUBase : public AttentionBase { const T* past_value, // past value only (if not using past state) T* present, // present state T* present_value, // present value only (if not using present state) - ThreadPool* tp) const { + ThreadPool* tp, + bool past_present_share_buffer = false, + int max_sequence_length = 0) const { const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L const ptrdiff_t past_chunk_length = SafeInt(past_sequence_length) * v_head_size; // P x H_v const ptrdiff_t q_input_chunk_length = SafeInt(sequence_length) * v_head_size; // S x H_v const ptrdiff_t kv_input_chunk_length = SafeInt(kv_sequence_length) * v_head_size; // L x H_v const ptrdiff_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H_v + const ptrdiff_t cache_chunk_length = SafeInt(max_sequence_length) * v_head_size; // M x H_v // Move the pointer of past and present to start of v values. if (nullptr != past) { @@ -281,7 +313,8 @@ class AttentionCPUBase : public AttentionBase { unit_cost.bytes_stored = static_cast(sequence_length * v_head_size * sizeof(T)); if (present || present_value) { - double bytes_to_copy_value = static_cast(present_chunk_length * sizeof(T)); + double bytes_to_copy_value = (past_present_share_buffer ? kv_input_chunk_length : present_chunk_length) * + static_cast(sizeof(T)); unit_cost.bytes_loaded += bytes_to_copy_value; unit_cost.bytes_stored += bytes_to_copy_value; } @@ -299,7 +332,12 @@ class AttentionCPUBase : public AttentionBase { // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i); } else if (nullptr != present_value) { - v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i); + if (past_present_share_buffer) { + v = present_value + cache_chunk_length * i; + memcpy(const_cast(v) + past_chunk_length, V + v_head_size * i, v_head_size * sizeof(T)); + } else { + v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i); + } } T* current_tmp_data = reinterpret_cast(tmp_buffer) + q_input_chunk_length * i; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 4d435f71cc195..37bb5664393c9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -120,9 +120,10 @@ void PrepareMask(const int32_t* mask_index, bool causal, int batch_size, int sequence_length, + int kv_sequence_length, int past_sequence_length, float mask_filter_value) { - const int all_sequence_length = past_sequence_length + sequence_length; + const int all_sequence_length = past_sequence_length + kv_sequence_length; // mask_data has been filled with 0, and its shape is BxSxT T* p_mask = mask_data; diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc new file mode 100644 index 0000000000000..e6f65f92e14f4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc @@ -0,0 +1,473 @@ +// 
Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "attention_cpu_base.h" +#include "attention_utils.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/cpu/bert/decoder_masked_multihead_attention.h" + +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { + +// TODO: refactor +static constexpr int kPastSequenceLengthInputIndex = 7; +static constexpr int kBeamWidthInputIndex = 8; +static constexpr int kCacheIndirectionInputIndex = 9; +static constexpr int kPastInputIndex = 5; +static constexpr int kPresentOutputIndex = 1; +static constexpr int kQKOutputIndex = 3; +static constexpr int kBiasIndex = 10; + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DecoderMaskedMultiHeadAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ + .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ + DecoderMaskedMultiHeadAttention); + +REGISTER_KERNEL_TYPED(float) + +template +DecoderMaskedMultiHeadAttention::DecoderMaskedMultiHeadAttention(const OpKernelInfo& info) + : OpKernel(info), AttentionCPUBase(info, false) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL); + output_qk_ = info.GetAttrOrDefault("output_qk", 0LL); +} + +template +Status DecoderMaskedMultiHeadAttention::Compute(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* mask_index = context->Input(3); + const Tensor* attention_bias = context->Input(4); + const Tensor* past_key = context->Input(kPastInputIndex); + const Tensor* past_value = context->Input(kPastInputIndex + 1); + const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); + const Tensor* beam_width = context->Input(kBeamWidthInputIndex); + const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); + const Tensor* bias = context->Input(kBiasIndex); + + DecoderMaskedMultiHeadAttentionParams parameters; + + bool is_unidirectional = false; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, + key, + value, + bias, + mask_index, + attention_bias, + past_key, + past_value, + past_seq_len, + ¶meters, + num_heads_, + mask_filter_value_, + scale_, + is_unidirectional, + past_present_share_buffer_, + kDecoderMaskedMultiHeadAttention)); + + int batch_size = parameters.batch_size; + int sequence_length = parameters.sequence_length; + int head_size = parameters.head_size; + int v_head_size = parameters.v_head_size; + int hidden_size = parameters.hidden_size; + int v_hidden_size = parameters.v_hidden_size; + + // This kernel is for decoding only (i.e.) 
sequence length has to be 1 + if (sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention. " + "Actual length is ", + sequence_length); + } + + if (head_size != v_head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "QK head size should be same as V head size to use DecoderMaskedMultiHeadAttention"); + } + + if (parameters.mask_type != AttentionMaskType::MASK_2D_KEY_PADDING && + parameters.mask_type != AttentionMaskType::MASK_NONE) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "DecoderMaskedMultiHeadAttention only supports no mask or 2D key " + "padding mask of shape [batch, total_seq_length] currently"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims{ + parameters.batch_size, parameters.num_heads, + past_present_share_buffer_ ? parameters.max_sequence_length : parameters.total_sequence_length, + head_size}; + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(kPresentOutputIndex, present_shape); + Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape); + Tensor* output_qk = nullptr; + + // Decoder cross-attention + if (past_key == nullptr && present_key == nullptr) { + if (attention_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "DecoderMaskedMultiHeadAttention does not support attention bias for cross-attention"); + } + + parameters.is_cross_attention = true; + parameters.total_sequence_length = parameters.kv_sequence_length; + parameters.max_sequence_length = parameters.kv_sequence_length; + } else { + // Sanity check + ORT_ENFORCE(past_present_share_buffer_); + ORT_ENFORCE(past_key != nullptr && past_value != nullptr); + + auto* present_key_data = present_key->MutableData(); + auto* present_value_data = present_value->MutableData(); + auto* past_key_data = past_key->Data(); + auto* past_value_data = past_value->Data(); + + if (present_key_data != past_key_data) { + std::memcpy(present_key_data, past_key_data, past_key->SizeInBytes()); + } + if (present_value_data != past_value_data) { + std::memcpy(present_value_data, past_value_data, past_value->SizeInBytes()); + } + + parameters.is_cross_attention = false; + } + + if (output_qk_) { + int64_t qk_dims[] = {parameters.batch_size, parameters.num_heads, 1, parameters.total_sequence_length}; + TensorShape qk_shape(&qk_dims[0], sizeof(qk_dims) / sizeof(qk_dims[0])); + output_qk = context->Output(kQKOutputIndex, qk_shape); + } + + // Beam width (in case we are using this op inside BeamSearch) + int beam_width_value = 1; + if (beam_width != nullptr) { + beam_width_value = static_cast(*beam_width->Data()); + } + + // Cache indirection (in case we are using this op inside BeamSearch) + if (beam_width_value > 1 && cache_indir == nullptr) { + // If beam width > 1, then cache indirection buffer MUST be present + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "If beam width is greater than 1, then cache indirection buffer MUST be present"); + } + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + OrtValue Q; + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, 1, head_size, query, bias, 0, Q)); + + // Cross-attention case 
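+  // In this path the incoming key/value tensors are consumed directly as the (B, N, L, H)
+  // cross K/V, so there is no KV cache to append to and no beam reordering to apply.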
+ if (parameters.is_cross_attention) { + return ApplyAttention(Q.GetMutable()->MutableData(), + key->Data(), + value->Data(), + mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value, + batch_size, 1 /* sequence_length */, parameters.kv_sequence_length, + head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk); + } + + OrtValue K, V; + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, 1, head_size, key, bias, hidden_size, K)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, 1, v_head_size, value, bias, 2 * hidden_size, V)); + + // Self-attention, !has_beams + if (cache_indir == nullptr) { + return ApplyAttention(Q.GetMutable()->MutableData(), + K.GetMutable()->MutableData(), + V.GetMutable()->MutableData(), + mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value, + batch_size, 1 /* sequence_length */, parameters.kv_sequence_length, + head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk, + parameters.past_sequence_length, true /* past_present_share_buffer */); + } + + // Self-attention, has_beams + return ApplyAttentionWithBeams(Q.GetMutable()->MutableData(), + K.GetMutable()->MutableData(), + V.GetMutable()->MutableData(), + mask_index, past_key, past_value, output, present_key, present_value, + batch_size, parameters.past_sequence_length, parameters.max_sequence_length, + head_size, v_head_size, attention_bias, parameters.broadcast_attn_bias_dim_0, + parameters.broadcast_attn_bias_dim_1, cache_indir, context, + beam_width_value, output_qk); +} + +template +Status DecoderMaskedMultiHeadAttention::ApplyAttentionWithBeams( + const T* Q, + const T* K, + const T* V, + const Tensor* mask_index, + const Tensor* past_key, + const Tensor* past_value, + Tensor* output, + Tensor* present_key, + Tensor* present_value, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + int v_head_size, + const Tensor* attn_bias, + bool broadcast_attn_bias_dim_0, + bool broadcast_attn_bias_dim_1, + const Tensor* cache_indir, + OpKernelContext* context, + int beam_width, + Tensor* output_qk) const { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + auto* tp = context->GetOperatorThreadPool(); + + int total_sequence_length = past_sequence_length + 1; + size_t bytes = SafeInt(batch_size) * num_heads_ * total_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + + T* output_qk_data = (output_qk != nullptr) ? output_qk->MutableData() : nullptr; + + const int32_t* mask_index_data = mask_index != nullptr ? mask_index->Data() : nullptr; + const T* attn_bias_data = attn_bias != nullptr ? 
attn_bias->Data() : nullptr; + + ComputeAttentionProbsWithBeams(static_cast(attention_probs), Q, K, mask_index_data, batch_size, + past_sequence_length, max_sequence_length, head_size, past_key->Data(), + present_key->MutableData(), tp, attn_bias_data, broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, cache_indir->Data(), beam_width, output_qk_data); + + // Compute the attentionScore * Value: out_tmp(B, N, 1, H_v) = attention_probs(B, N, 1, T) x V(B, N, T, H_v) + auto out_tmp_data = allocator->Alloc(SafeInt(batch_size) * num_heads_ * v_head_size * sizeof(T)); + BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator))); + + ComputeVxAttentionScoreWithBeams(output->MutableData(), static_cast(out_tmp_data), + static_cast(attention_probs), V, batch_size, + past_sequence_length, max_sequence_length, v_head_size, past_value->Data(), + present_value->MutableData(), cache_indir->Data(), beam_width, tp); + + return Status::OK(); +} + +template +void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( + T* attention_probs, + const T* Q, + const T* K, + const int32_t* mask_index_data, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + const T* past_key_data, + T* present_key_data, + ThreadPool* tp, + const T* attn_bias_data, + bool broadcast_attn_bias_dim_0, + bool broadcast_attn_bias_dim_1, + const int32_t* cache_indir_data, + int beam_width, + T* output_qk_data) const { + float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + TensorOpCost unit_cost; + auto total_sequence_length = past_sequence_length + 1; + const ptrdiff_t probs_matrix_size = total_sequence_length; + const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T); + + unit_cost.compute_cycles = static_cast((SafeInt(2) * head_size - 1) * total_sequence_length); + unit_cost.bytes_loaded = static_cast(SafeInt(2) * head_size * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(SafeInt(head_size) * total_sequence_length * sizeof(T)); + + if (attn_bias_data != nullptr) { + unit_cost.bytes_loaded += static_cast(probs_matrix_bytes) * 2; + unit_cost.bytes_stored += probs_matrix_bytes; + } + + if (mask_index_data != nullptr) { + unit_cost.bytes_stored += probs_matrix_bytes; + } + + // Cost of appending current key to present key + unit_cost.compute_cycles += static_cast(head_size); + unit_cost.bytes_loaded += static_cast(head_size); + + // Parallel for loop + const int loop_len = batch_size * num_heads_; + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const std::ptrdiff_t batch_index = i / num_heads_; + const std::ptrdiff_t head_index = i % num_heads_; + const std::ptrdiff_t beam_batch_index = batch_index / beam_width; + const T* q_vec = Q + i * head_size; + const std::ptrdiff_t attn_bias_base_offset = ((broadcast_attn_bias_dim_0 ? 0 : (beam_batch_index * num_heads_)) + + (broadcast_attn_bias_dim_1 ? 
0 : head_index)) * + probs_matrix_size; + + { + // Calculate the latest position of the attention_probs + // (1, H) x (T, H)^T -> (1, T) + // Decompose into T (1, H) x (1, H)^T -> (1, 1) operations + auto last_offset = past_sequence_length + i * probs_matrix_size; + T* attention_probs_ptr = reinterpret_cast(attention_probs) + last_offset; + math::Dot(head_size, q_vec, K + i * head_size, attention_probs_ptr, nullptr); + + *attention_probs_ptr *= scale; + // Apply the attention bias and mask + if (attn_bias_data != nullptr) { + *attention_probs_ptr += attn_bias_data[attn_bias_base_offset + past_sequence_length]; + } + bool is_masked = (mask_index_data != nullptr) && + (mask_index_data[(batch_index + 1) * total_sequence_length - 1] == 0); + if (is_masked) { + *attention_probs_ptr += mask_filter_value_; + } + } + + { + // Calculate the rest of the attention_probs + for (std::ptrdiff_t j = 0; j < past_sequence_length; ++j) { + const int* beam_indices = &cache_indir_data[batch_index * max_sequence_length]; + const std::ptrdiff_t beam_offset = static_cast(beam_indices[j]) * num_heads_ * + max_sequence_length * head_size; + const std::ptrdiff_t beam_batch_offset = (beam_batch_index * beam_width * num_heads_ + head_index) * + max_sequence_length * head_size; + const T* past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size; + T* output = reinterpret_cast(attention_probs) + j + i * probs_matrix_size; + math::Dot(head_size, q_vec, past_k_vec, output, nullptr); + + *output *= scale; + // Apply the attention bias and mask + if (attn_bias_data != nullptr) { + *output += attn_bias_data[attn_bias_base_offset + j]; + } + bool is_masked = (mask_index_data != nullptr) && + (mask_index_data[batch_index * total_sequence_length + j] == 0); + if (is_masked) { + *output += mask_filter_value_; + } + } + } + // Append current key to present key (past_present_share_buffer_ is true) + memcpy(present_key_data + (i * max_sequence_length + past_sequence_length) * head_size, + K + i * head_size, head_size * sizeof(T)); + } + }); + + if (output_qk_data != nullptr) { + // Output the scaled Q*K^T if needed. 
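+    // The copied block has shape (B, N, 1, T): one decoding step per batch*head,
+    // taken before the softmax that follows.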
+ memcpy(output_qk_data, attention_probs, + SafeInt(batch_size) * num_heads_ * total_sequence_length * sizeof(T)); + } + + // attention_probs(B, N, 1, T) = Softmax(attention_probs) + { + const int N = batch_size * num_heads_; + const int D = total_sequence_length; + ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp); + } +} + +template +void DecoderMaskedMultiHeadAttention::ComputeVxAttentionScoreWithBeams( + T* output, + T* tmp_buffer, + const T* attention_probs, + const T* V, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int v_head_size, + const T* past_value_data, + T* present_value_data, + const int32_t* cache_indir_data, + int beam_width, + ThreadPool* tp) const { + const int total_sequence_length = past_sequence_length + 1; + + TensorOpCost unit_cost; + unit_cost.compute_cycles = static_cast(SafeInt(2) * v_head_size * total_sequence_length); + unit_cost.bytes_loaded = static_cast(SafeInt(3) * v_head_size * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(SafeInt(2) * v_head_size * total_sequence_length * sizeof(T)); + + // Cost of appending current value to present value + unit_cost.compute_cycles += static_cast(v_head_size); + unit_cost.bytes_loaded += static_cast(v_head_size); + + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const std::ptrdiff_t batch_index = i / num_heads_; + const std::ptrdiff_t head_index = i % num_heads_; + const std::ptrdiff_t beam_batch_index = batch_index / beam_width; + + // Compute the attention score + // (1, T) x (T, H_v) -> (1, H_v) + // Decompose into T (1, 1) x (1, H_v) -> (1, H_v) operations and accumulate. + { + const T* attn_probs_ptr = attention_probs + (i + 1) * total_sequence_length - 1; + math::Scale(v_head_size, + static_cast(*attn_probs_ptr), + V + i * v_head_size, + output + i * v_head_size, + nullptr); + } + { + for (std::ptrdiff_t j = 0; j < past_sequence_length; ++j) { + const int* beam_indices = &cache_indir_data[batch_index * max_sequence_length]; + const std::ptrdiff_t beam_offset = static_cast(beam_indices[j]) * num_heads_ * + max_sequence_length * v_head_size; + const std::ptrdiff_t beam_batch_offset = (beam_batch_index * beam_width * num_heads_ + head_index) * + max_sequence_length * v_head_size; + const T* past_value_vec = past_value_data + beam_offset + beam_batch_offset; + const T* attn_probs_ptr = attention_probs + j + i * total_sequence_length; + + math::Scale(v_head_size, + static_cast(*attn_probs_ptr), + past_value_vec + j * v_head_size, + tmp_buffer + i * v_head_size, + nullptr); + math::Add(v_head_size, + output + i * v_head_size, + tmp_buffer + i * v_head_size, + output + i * v_head_size, + nullptr); + } + } + // Append current value to present value (past_present_share_buffer_ is true) + memcpy(present_value_data + (i * max_sequence_length + past_sequence_length) * v_head_size, + V + i * v_head_size, + v_head_size * sizeof(T)); + } + }); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h new file mode 100644 index 0000000000000..d5167e8989669 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
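To make the beam-search addressing in ComputeAttentionProbsWithBeams above easier to follow, here is a standalone sketch of the offset arithmetic with made-up sizes. The cache is laid out as [batch][beam][num_heads][max_sequence_length][head_size], and cache_indir picks, per past position, which beam's history to read; everything below is illustrative only.

```cpp
// Minimal sketch of the cache-indirection addressing used by the new CPU
// DecoderMaskedMultiHeadAttention kernel. All sizes are made up.
#include <cstddef>
#include <cstdio>

int main() {
  const int num_heads = 2, head_size = 4, max_sequence_length = 8, beam_width = 3;

  // The parallel loop index i enumerates (batch * beam) * num_heads work items.
  const int i = 7;                                         // e.g. batch*beam entry 3, head 1
  const int batch_index = i / num_heads;                   // index into the expanded (batch*beam) dim
  const int head_index = i % num_heads;
  const int beam_batch_index = batch_index / beam_width;   // original batch entry

  // For past position j, cache_indir[batch_index * max_sequence_length + j]
  // names the beam whose cached K/V should be read. Pretend it returned 1.
  const int j = 5;
  const int cache_indir_value = 1;

  const std::ptrdiff_t beam_offset =
      static_cast<std::ptrdiff_t>(cache_indir_value) * num_heads * max_sequence_length * head_size;
  const std::ptrdiff_t beam_batch_offset =
      static_cast<std::ptrdiff_t>(beam_batch_index * beam_width * num_heads + head_index) *
      max_sequence_length * head_size;

  // past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size
  std::printf("element offset into past_key: %td\n",
              beam_batch_offset + beam_offset + j * head_size);
  return 0;
}
```

In this layout, beam_batch_offset points at beam 0 of the original batch entry for the current head, and beam_offset then jumps to the beam recorded in cache_indir, which is how one hypothesis can reuse tokens that were generated on a different beam.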
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionCPUBase { + public: + DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); + Status ApplyAttentionWithBeams(const T* Q, + const T* K, + const T* V, + const Tensor* mask_index, + const Tensor* past_key, + const Tensor* past_value, + Tensor* output, + Tensor* present_key, + Tensor* present_value, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + int v_head_size, + const Tensor* attn_bias, + bool broadcast_attn_bias_dim_0, + bool broadcast_attn_bias_dim_1, + const Tensor* cache_indir, + OpKernelContext* context, + int beam_width, + Tensor* output_qk = nullptr) const; + void ComputeAttentionProbsWithBeams(T* attention_probs, + const T* Q, + const T* K, + const int32_t* mask_index_data, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + const T* past_key, + T* present_key, + ThreadPool* tp, + const T* attn_bias_data, + bool broadcast_attn_bias_dim_0, + bool broadcast_attn_bias_dim_1, + const int32_t* cache_indir_data, + int beam_width, + T* output_qk_data = nullptr) const; + void ComputeVxAttentionScoreWithBeams(T* output, + T* tmp_buffer, + const T* attention_probs, + const T* V, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int v_head_size, + const T* past_value, + T* present_value, + const int32_t* cache_indir_data, + int beam_width, + ThreadPool* tp) const; + Status Compute(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + float mask_filter_value_; + float scale_; + bool past_present_share_buffer_; + bool output_qk_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 0bdee151d2173..4cc5a4228dc8c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -11,18 +11,19 @@ namespace onnxruntime { namespace contrib { namespace group_query_attention_helper { -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, +template +Status CheckInputs(const T* query, + const T* key, + const T* value, + const T* past_key, + const T* past_value, + const T* cos_cache, + const T* sin_cache, void* parameters, int num_heads, int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, + const T* seqlens_k, + const T* total_seqlen, float scale, float softcap) { // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache @@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query, return Status::OK(); } -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, +template +Status CheckInputs(const T* query, + const T* key, + const T* value, + const T* past_key, + const T* past_value, + const T* cos_cache, + const T* sin_cache, void* parameters, int num_heads, int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, + const T* seqlens_k, + const T* total_seqlen, 
float scale, float softcap, int max_threads_per_block) { diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index cbfd2f0949363..9a6c2af022c91 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/bert/rotary_embedding.h" #include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" using onnxruntime::concurrency::ThreadPool; @@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; - int cache_idx = 0; - bool sign = false; - int j = 0; - for (int i = 0; i < rotary_emb_dim; i++) { - if (interleaved) { - cache_idx = (i / 2) % half_rotary_emb_dim; - sign = i & 1; - j = sign ? i - 1 : i + 1; // i - sign - } else { - cache_idx = i % half_rotary_emb_dim; - sign = (i >= half_rotary_emb_dim); - j = (i + half_rotary_emb_dim) % rotary_emb_dim; - } - float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); - float input_data_j = static_cast(input_data[j]); - float sin_data_cache_idx = static_cast(sin_data[cache_idx]); - if (sign) { - output_data_i += input_data_j * sin_data_cache_idx; - } else { - output_data_i -= input_data_j * sin_data_cache_idx; - } - output_data[i] = static_cast(output_data_i); - } - for (int i = rotary_emb_dim; i < head_size; i++) { - output_data[i] = input_data[i]; + MlasRotaryEmbedOneRow(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data); + + if (rotary_emb_dim < head_size) { + std::memcpy(output_data + rotary_emb_dim, + input_data + rotary_emb_dim, + (head_size - rotary_emb_dim) * sizeof(T)); } } }); diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 6ffe861d19931..c742cd1e95bdd 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -61,6 +61,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastG class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); @@ -148,6 +149,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping); #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen); @@ -333,6 +336,7 @@ Status 
RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to main backward compatibility BuildKernelCreateInfo, @@ -358,6 +362,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 89e96543c4729..c3e43f897c509 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -32,24 +32,47 @@ constexpr size_t A = 0, bias = 5; }; -int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { - const auto accuracy_level = std::clamp(accuracy_level_attr, - static_cast(CompMostAccurate), - static_cast(CompLeastAccurate)); - - // Find a supported accuracy level that is not less accurate than the one given. - // CompMostAccurate is always supported with the fallback implementation. - // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. - int64_t effective_accuracy_level = accuracy_level; - for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { - const auto compute_type = static_cast(effective_accuracy_level); - if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { - break; - } +typedef enum { + Level0, /*!< input fp32, accumulator fp32 */ + Level1, /*!< input fp32, accumulator fp32 */ + Level2, /*!< input fp16, accumulator fp16 */ + Level3, /*!< input bf16, accumulator fp32 */ + Level4, /*!< input int8, accumulator int32 */ +} ACCURACY_LEVEL; + +// T: A data type. +template +MLAS_QNBIT_GEMM_COMPUTE_TYPE +GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + // For Fp32, only accuracy level 1 or 4 makes sense. + // non-ARM CPU converts Fp16 to Fp32. + // By converting Fp32 to Fp16, precision becomes worse. And due to the casting, + // there is no performance gain. + if (accuracy_level_attr == static_cast(Level4) && + MlasIsQNBitGemmAvailable(nbits, block_size, SQNBIT_CompInt8)) { + return SQNBIT_CompInt8; } - return effective_accuracy_level; + return SQNBIT_CompFp32; } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +template <> +MLAS_QNBIT_GEMM_COMPUTE_TYPE +GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + // For Fp16, only accuracy level 2 or 4 makes sense. + // By converting Fp16 to Fp32, there is not precision increase, and the performance + // becomes worse. + if (accuracy_level_attr == static_cast(Level4) && + MlasIsQNBitGemmAvailable(nbits, block_size, HQNBIT_CompInt8)) { + return HQNBIT_CompInt8; + } + + // if HQNBIT_CompFp16 is not supported, will fallback to unpacked computation. 
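+  // Every accuracy level other than 4 therefore maps to HQNBIT_CompFp16 here
+  // (fp16 inputs with an fp16 accumulator), matching the Level2 entry above.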
+ return HQNBIT_CompFp16; +} +#endif // !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 + } // namespace bool GetType(const NodeArg& node_arg, int32_t& type) { @@ -74,10 +97,9 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))}, has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()}, has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()}, - compute_type_{static_cast(accuracy_level_)} { + compute_type_{GetComputeType(nbits_, block_size_, info.GetAttr("accuracy_level"))} { const auto& node = info.node(); auto input_defs = node.InputDefs(); const NodeArg* zero_point_arg = @@ -109,10 +131,9 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; - const int64_t accuracy_level_; const bool has_g_idx_; const bool has_bias_; - const MLAS_SQNBIT_GEMM_COMPUTE_TYPE compute_type_; + const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_; bool has_unquantized_zero_point_{false}; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; @@ -143,9 +164,7 @@ class MatMulNBits final : public OpKernel { Tensor* y, AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, - const MatMulComputeHelper& helper) const { - ORT_THROW("ComputeBPacked is not supported for T1 type."); - } + const MatMulComputeHelper& helper) const; }; template @@ -158,28 +177,28 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } if (input_idx == InputIndex::B) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); if (packed_b_size_ == 0) { return Status::OK(); } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; - } else if (compute_type_ == CompInt8) { + } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } #endif // MLAS_TARGET_AMD64_IX86 @@ -188,6 +207,8 @@ 
Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || !defined(MLAS_TARGET_ARM64) +// Non-ARM-with-fp16-intrinsics fall back fp16 to fp32. template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -211,29 +232,29 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou return Status::OK(); } - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } if (input_idx == InputIndex::B) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); if (packed_b_size_ == 0) { return Status::OK(); } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), - nullptr, has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), + nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; - } else if (compute_type_ == CompInt8) { + } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - scales_fp32_.get(), has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + scales_fp32_.get(), has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - nullptr, has_zp_input_, zptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } #endif // MLAS_TARGET_AMD64_IX86 @@ -241,6 +262,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou return Status::OK(); } +#endif // end !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 template Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, @@ -255,20 +277,20 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& return Status::OK(); } -template <> -Status MatMulNBits::ComputeBPacked(const Tensor* a, - const Tensor* scales, - const Tensor* zero_points, - const Tensor* bias, - Tensor* y, - AllocatorPtr& allocator, - concurrency::ThreadPool* thread_pool, - const MatMulComputeHelper& helper) const { - const auto* a_data = a->Data(); - const auto* scales_data = scales->Data(); +template +Status MatMulNBits::ComputeBPacked(const Tensor* a, + const Tensor* scales, + const Tensor* zero_points, + const Tensor* bias, + Tensor* y, + AllocatorPtr& allocator, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const { + const auto* a_data = a->Data(); + const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* bias_data = bias == nullptr ? 
nullptr : bias->Data(); - auto* y_data = y->MutableData(); + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); + auto* y_data = y->MutableData(); const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); @@ -277,19 +299,19 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; - const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { // Use reserve since no caching is needed workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } - InlinedVector data(batch_count); + InlinedVector> data(batch_count); for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type_ == CompInt8) { + if (compute_type_ == SQNBIT_CompInt8) { data[i].QuantBDataWorkspace = packed_b_.get(); } #endif @@ -300,11 +322,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), - thread_pool); + MlasQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), + thread_pool); return Status::OK(); } +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || !defined(MLAS_TARGET_ARM64) template <> Status MatMulNBits::ComputeBPacked(const Tensor* a, const Tensor* scales, @@ -327,7 +350,7 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; - const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { // Use reserve since no caching is needed @@ -361,12 +384,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, size_t c_size = static_cast(y->Shape().Size()); std::vector c_v(c_size); - InlinedVector data(batch_count); + InlinedVector> data(batch_count); for (size_t i = 0; i < batch_count; ++i) { data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type_ == CompInt8) { + if (compute_type_ == SQNBIT_CompInt8) { data[i].QuantBDataWorkspace = packed_b_.get(); } #endif @@ -377,11 +400,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, data[i].C = c_v.data() + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), - thread_pool); + MlasQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), + thread_pool); MlasConvertFloatToHalfBuffer(c_v.data(), y_data, c_size); return Status::OK(); } +#endif // end of !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_AMD64 template <> Status MatMulNBits::ComputeBUnpacked(const Tensor* a, @@ -517,9 +541,10 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t ldb = helper.Ldb(true); float* scales_ptr = nullptr; + IAllocatorUniquePtr temp_scales; if (!scales_fp32_) { auto scales_size = static_cast(scales->Shape().Size()); - auto temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); + temp_scales = 
IAllocator::MakeUniquePtr(allocator, scales_size, true); MlasConvertHalfToFloatBuffer(scales_data, temp_scales.get(), scales_size); scales_ptr = temp_scales.get(); } else { @@ -600,8 +625,9 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, if (bias) { float* bias_ptr = nullptr; const size_t bias_size = static_cast(bias->Shape().Size()); + IAllocatorUniquePtr bias_temp; if (!bias_fp32_) { - auto bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); + bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); MlasConvertHalfToFloatBuffer(bias->Data(), bias_temp.get(), bias_size); bias_ptr = bias_temp.get(); } else { @@ -654,11 +680,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // clang-format on if (has_single_b_matrix && - packed_b_) { // Assume that MlasSQNBitGemmBatch() always requires packed B. - // If this changes, i.e., if MlasIsSQNBitGemmAvailable() can return true while - // MlasSQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasSQNBitGemmBatch() + packed_b_) { // Assume that MlasQNBitGemmBatch() always requires packed B. + // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while + // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. - if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } } diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index faf78cae80ee1..d5b8961cf8c5a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/tensor.h" +#include "core/mlas/inc/mlas.h" #include "core/util/math_cpuonly.h" #include "core/providers/common.h" #include "core/platform/threadpool.h" @@ -36,52 +37,87 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) -// Utility to convert from MLFloat16 to float only when the input type is MLFloat16. -template -ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val); - -template <> -ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { - return val.ToFloat(); -} - -template <> -ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { - return static_cast(ConvertMLFloat16ToDoubleOrFloatIfNeeded(val)); -} - -template <> -ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded(float val) { - return val; -} - -template <> -ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded(double val) { - return val; +namespace { + +template || std::is_same_v, void>> +void ComputeJob( + const T* input_data, + const T* skip_data, + const T* gamma_data, + const T* beta_data, + const T* bias_data, + ptrdiff_t task_idx, + int hidden_size, + int64_t skip_size, + float epsilon, + bool simplified, + T* output_data, + T* skip_input_bias_add_output_data) { + auto offset = task_idx * hidden_size; + const T* p_input = input_data + offset; + const T* p_skip = skip_data + (offset % skip_size); + T* p_output = output_data + offset; + T* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? 
nullptr : skip_input_bias_add_output_data + offset; + + T mean(0.0f); + T mean_square(0.0f); + + for (decltype(hidden_size) h = 0; h < hidden_size; h++) { + T val = p_input[h] + p_skip[h]; + + if (nullptr != bias_data) { + val += bias_data[h]; + } + + if (nullptr != p_skip_input_bias_add_output) { + p_skip_input_bias_add_output[h] = val; + } + + p_output[h] = val; + mean += val; + mean_square += val * val; + } + + mean = mean / hidden_size; + if (simplified) { + mean_square = sqrt(mean_square / hidden_size + epsilon); + } else { + mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon); + } + + for (decltype(hidden_size) h = 0; h < hidden_size; h++) { + if (simplified) { + p_output[h] = p_output[h] / mean_square * gamma_data[h]; + } else if (nullptr == beta_data) { + p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h]; + } else { + p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; + } + } } -// Function template that only converts the input value to MLFloat16 if T is MLFloat16. -template -ORT_FORCEINLINE constexpr typename std::enable_if_t || std::is_same_v, T> -ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) { - return val; -} +void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr& dest, bool& is_packed) { + if (tensor.GetElementType() == utils::ToTensorProtoElementType()) { + auto tensor_data_ptr = tensor.Data(); + auto tensor_size = static_cast(tensor.Shape().Size()); + auto float_ptr = IAllocator::MakeUniquePtr(alloc, tensor_size, true); -template -ORT_FORCEINLINE constexpr typename std::enable_if_t, T> -ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) { - return MLFloat16(val); + MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size); + dest = std::move(float_ptr); + is_packed = true; + } } -template -ORT_FORCEINLINE constexpr typename std::enable_if_t, T> -ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) { - return MLFloat16(static_cast(val)); -} +} // namespace template SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) - : OpKernel(op_kernel_info) { + : OpKernel(op_kernel_info), + prepacked_skip_fp32_size_(0), + prepacked_skip_fp32_data_(nullptr), + prepacked_gamma_fp32_data_(nullptr), + prepacked_beta_fp32_data_(nullptr), + prepacked_bias_fp32_data_(nullptr) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } @@ -89,103 +125,153 @@ SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) template Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); - const Tensor* skip = p_ctx->Input(1); - const Tensor* gamma = p_ctx->Input(2); - const Tensor* beta = p_ctx->Input(3); - const Tensor* bias = p_ctx->Input(4); + const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input(1); + const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input(2); + const Tensor* beta = simplified ? nullptr : (prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input(3)); + const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(simplified ? 
3 : 4); Tensor* output = p_ctx->Output(0, input->Shape()); - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors + // For inferencing, we support one more optional output which is the sum of the input and skip tensors Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape()); const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); int hidden_size = static_cast(input_dims[input_dims_size - 1]); - ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, - skip, - gamma, - beta, - bias, - hidden_size, - input_dims_size)); + ORT_RETURN_IF_ERROR(skip_layer_norm_helper::CheckPotentiallyPrepackedInputs(input, + skip, + gamma, + beta, + bias, + hidden_size, + input_dims_size, + prepacked_skip_fp32_data_ != nullptr, + prepacked_gamma_fp32_data_ != nullptr)); int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1); const T* input_data = input->Data(); - const T* skip_data = skip->Data(); - const T* gamma_data = gamma->Data(); + const T* skip_data = skip == nullptr ? nullptr : skip->Data(); + const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data(); const T* beta_data = beta == nullptr ? nullptr : beta->Data(); const T* bias_data = bias == nullptr ? nullptr : bias->Data(); T* output_data = output->MutableData(); - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData() : nullptr; - - const auto& skip_size = skip->Shape().Size(); - - concurrency::ThreadPool::TryBatchParallelFor( - p_ctx->GetOperatorThreadPool(), static_cast(task_count), - [&](ptrdiff_t task_idx) { - auto offset = task_idx * hidden_size; - - const T* p_input = input_data + offset; - const T* p_skip = skip_data + (offset % skip_size); - T* p_output = output_data + offset; - T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? 
skip_input_bias_add_output_data + offset : nullptr; - - using DoubleOrFloat = typename std::conditional< - std::is_same::value, // If T is double - double, // Use double - float // Otherwise, use float (covers float and MLFloat16) - >::type; - - DoubleOrFloat mean(0.0f); - DoubleOrFloat mean_square(0.0f); - - std::unique_ptr output_buffer = std::make_unique(hidden_size); - for (size_t h = 0; h < static_cast(hidden_size); h++) { - DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_input[h]); - DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_skip[h]); - - DoubleOrFloat value = input_value + skip_value; - - if (nullptr != bias_data) { - value += ConvertMLFloat16ToDoubleOrFloatIfNeeded(bias_data[h]); - } - - output_buffer[h] = value; - T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded(value); - if (nullptr != p_skip_input_bias_add_output_data) { - p_skip_input_bias_add_output_data[h] = converted_value; - } - - mean += value; - mean_square += value * value; - } - - mean = mean / hidden_size; - if (simplified) { - mean_square = sqrt(mean_square / hidden_size + epsilon_); - } else { - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); - } - - for (size_t h = 0; h < static_cast(hidden_size); h++) { - DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(gamma_data[h]); - if (simplified) { - p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded(output_buffer[h] / mean_square * gamma_value); - } else if (nullptr == beta_data) { - p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded((output_buffer[h] - mean) / mean_square * gamma_value); - } else { - DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(beta_data[h]); - p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded((output_buffer[h] - mean) / mean_square * gamma_value + beta_value); - } - } - }, - 0); + // For inferencing, we support one more optional output which is the sum of the input and skip tensors + T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData(); + const int64_t skip_size = skip ? 
skip->Shape().Size() : prepacked_skip_fp32_size_; + + if constexpr (std::is_same_v) { + const size_t total_data_size = static_cast(input->Shape().Size()); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + + IAllocatorUniquePtr input_fp32; + IAllocatorUniquePtr output_fp32; + IAllocatorUniquePtr skip_input_bias_add_output_fp32; + IAllocatorUniquePtr skip_fp32; + IAllocatorUniquePtr gamma_fp32; + IAllocatorUniquePtr beta_fp32; + IAllocatorUniquePtr bias_fp32; + + const float* input_data_f = nullptr; + const float* skip_data_f = nullptr; + const float* gamma_data_f = nullptr; + const float* beta_data_f = nullptr; + const float* bias_data_f = nullptr; + float* output_data_f = nullptr; + float* skip_input_bias_add_output_data_f = nullptr; + + const size_t num_elems = static_cast(hidden_size); + + input_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size); + input_data_f = input_fp32.get(); + + output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + output_data_f = output_fp32.get(); + + skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get(); + + if (skip_data) { + skip_fp32 = IAllocator::MakeUniquePtr(alloc, static_cast(skip_size)); + MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), static_cast(skip_size)); + skip_data_f = skip_fp32.get(); + } else if (prepacked_skip_fp32_data_) { + skip_data_f = prepacked_skip_fp32_data_.get(); + } + + if (gamma_data) { + gamma_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems); + gamma_data_f = gamma_fp32.get(); + } else if (prepacked_gamma_fp32_data_) { + gamma_data_f = prepacked_gamma_fp32_data_.get(); + } + + if (beta_data) { + beta_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems); + beta_data_f = beta_fp32.get(); + } else if (prepacked_beta_fp32_data_) { + beta_data_f = prepacked_beta_fp32_data_.get(); + } + + if (bias_data) { + bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + bias_data_f = bias_fp32.get(); + } else if (prepacked_bias_fp32_data_) { + bias_data_f = prepacked_bias_fp32_data_.get(); + } + + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { + ComputeJob(input_data_f, skip_data_f, gamma_data_f, beta_data_f, bias_data_f, task_idx, hidden_size, skip_size, + epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f); + }, + 0); + MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size); + if (skip_input_bias_add_output_data != nullptr) + MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size); + } else { + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { + ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, + epsilon_, simplified, output_data, skip_input_bias_add_output_data); + }, + 0); + } + + return Status::OK(); +} + +template +Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) { + 
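// A standalone sketch of the per-row math that ComputeJob above implements: add the skip
// (and optional bias) to the input, then apply LayerNorm, or RMSNorm when `simplified` is
// true. Float only; the fp16 path above converts inputs with MlasConvertHalfToFloatBuffer,
// runs this math in fp32, and converts the result back to half.
#include <cmath>

inline void SkipLayerNormRow(const float* x, const float* skip, const float* bias,
                             const float* gamma, const float* beta,
                             int hidden_size, float epsilon, bool simplified, float* y) {
  float mean = 0.0f;
  float mean_square = 0.0f;
  for (int h = 0; h < hidden_size; ++h) {
    const float v = x[h] + skip[h] + (bias != nullptr ? bias[h] : 0.0f);
    y[h] = v;  // also the value written to the optional skip_input_bias_add output
    mean += v;
    mean_square += v * v;
  }
  mean /= hidden_size;
  const float denom = simplified
                          ? std::sqrt(mean_square / hidden_size + epsilon)
                          : std::sqrt(mean_square / hidden_size - mean * mean + epsilon);
  for (int h = 0; h < hidden_size; ++h) {
    const float centered = simplified ? y[h] : y[h] - mean;
    y[h] = centered / denom * gamma[h] + (!simplified && beta != nullptr ? beta[h] : 0.0f);
  }
}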
ORT_UNUSED_PARAMETER(prepacked_weights); + is_packed = false; + if (input_idx == 1) { // skip + prepacked_skip_fp32_size_ = tensor.Shape().Size(); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed); + } else if (input_idx == 2) { // gamma + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed); + } else if (input_idx == 3) { + if constexpr (simplified) { + // bias + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); + } else { + // beta + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); + } + } else if (input_idx == 4) { // bias + ORT_ENFORCE(!simplified, "SkipSimplifiedLayerNormalization should only has 4 inputs (input, skip, gamma, and beta). Got 5."); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); + } return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 69edf4609e340..4a350fdcc2220 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -16,8 +16,16 @@ class SkipLayerNorm final : public OpKernel { SkipLayerNorm(const OpKernelInfo& op_kernel_info); Status Compute(OpKernelContext* p_op_kernel_context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: float epsilon_; + int64_t prepacked_skip_fp32_size_; + IAllocatorUniquePtr prepacked_skip_fp32_data_; + IAllocatorUniquePtr prepacked_gamma_fp32_data_; + IAllocatorUniquePtr prepacked_beta_fp32_data_; + IAllocatorUniquePtr prepacked_bias_fp32_data_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h index 6271f822287e6..4c901f5650dbd 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h @@ -11,14 +11,10 @@ namespace onnxruntime { namespace contrib { namespace skip_layer_norm_helper { +namespace { + template -Status CheckInputs(const T* input, - const T* skip, - const T* gamma, - const T* beta, - const T* bias, - int hidden_size_check, - size_t input_dims_size_check) { +Status CheckSkip(const T* input, const T* skip, size_t input_dims_size_check) { const auto& input_dims_check = input->Shape().GetDims(); const auto& skip_dims_check = skip->Shape().GetDims(); size_t skip_dims_size_check = skip_dims_check.size(); @@ -33,49 +29,150 @@ Status CheckInputs(const T* input, "skip is expected to have same shape as input or, a batch size of 1 or no batch size when input has 3 dimensions"); } - if (input_dims_size_check != 3 && input_dims_size_check != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); - } - if (skip_dims_check[skip_dims_size_check - 1] != input_dims_check[input_dims_size_check - 1] || skip_dims_check[skip_dims_size_check - 2] != input_dims_check[input_dims_size_check - 2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last two dimensions of skip needs to be same as input"); } + return Status::OK(); +} + +template +Status CheckGamma(const T* gamma, int hidden_size_check) { const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 
dimension, got ", gamma_dims.size()); } + if (gamma_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of gamma and input does not match"); } + return Status::OK(); +} + +template +Status CheckBeta(const T* beta, int hidden_size_check) { if (nullptr != beta) { const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } + if (beta_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of beta and input does not match"); } } + return Status::OK(); +} + +template +Status CheckBias(const T* bias, int hidden_size_check) { if (nullptr != bias) { const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bias is expected to have 1 dimension, got ", bias_dims.size()); } + if (bias_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of bias and input does not match"); } } + + return Status::OK(); +} + +} // anonymous namespace + +template +Status CheckInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check) { + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + auto status = CheckSkip(input, skip, input_dims_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckGamma(gamma, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBeta(beta, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBias(bias, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + return Status::OK(); +} + +template +Status CheckPotentiallyPrepackedInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check, + bool prepacked_skip, + bool prepacked_gamma) { + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + if (nullptr != skip) { + auto status = CheckSkip(input, skip, input_dims_size_check); + if (status != Status::OK()) { + return status; + } + } else if (!prepacked_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "skip is expected but not provided"); + } + + if (nullptr != gamma) { + auto status = CheckGamma(gamma, hidden_size_check); + if (status != Status::OK()) { + return status; + } + } else if (!prepacked_gamma) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected but not provided"); + } + + auto status = CheckBeta(beta, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBias(bias, hidden_size_check); + if (status != Status::OK()) { + return status; + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc new file mode 100644 index 0000000000000..9f1d4d6e20307 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc @@ -0,0 +1,101 @@ +// Copyright 
(c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/tensor/dynamic_time_warping.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + DynamicTimeWarping, + kMSDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("F", DataTypeImpl::GetTensorType()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), + DynamicTimeWarping); + +Status DynamicTimeWarping::Compute(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + ORT_ENFORCE(rank == 2 || (rank == 3 && input_dims[0] == 1), + "Currently input rank must be 2, or (3 with first dim equal to 1), but got:", rank); + + const size_t rows = SafeInt(input_dims[rank == 3 ? 1 : 0]); + const size_t cols = SafeInt(input_dims[rank == 3 ? 2 : 1]); + + std::vector> cost(rows + 1, std::vector(cols + 1, std::numeric_limits::infinity())); + std::vector> trace(rows + 1, std::vector(cols + 1, -1)); + std::vector> path_helper; + + // Compute the cost and trace matrices + cost[0][0] = 0; + for (size_t j = 1; j <= cols; ++j) { + for (size_t i = 1; i <= rows; ++i) { + const float c0 = cost[i - 1][j - 1]; + const float c1 = cost[i - 1][j]; + const float c2 = cost[i][j - 1]; + + float cur_cost; + int8_t cur_trace; + if (c0 < c1 && c0 < c2) { + cur_cost = c0; + cur_trace = 0; + } else if (c1 < c0 && c1 < c2) { + cur_cost = c1; + cur_trace = 1; + } else { + cur_cost = c2; + cur_trace = 2; + } + + cost[i][j] = cur_cost + input_tensor.Data()[(i - 1) * cols + j - 1]; + trace[i][j] = cur_trace; + } + } + + // Back-tracing to find the optimal path + int i = static_cast(rows); + int j = static_cast(cols); + int result_len = 0; + while (i > 0 && j > 0) { + path_helper.push_back({i - 1, j - 1}); + ++result_len; + int8_t cur_trace = trace[i][j]; + switch (cur_trace) { + case 0: + --i; + --j; + break; + case 1: + --i; + break; + case 2: + --j; + break; + default: + ORT_THROW("Invalid trace value: ", cur_trace); + } + } + + // Update the output tensor + Tensor* output_tensor = ctx->Output(0, TensorShape{2LL, SafeInt(result_len)}); + auto* output_data = output_tensor->MutableData(); + for (int k = 0; k < result_len; ++k) { + output_data[k] = path_helper[static_cast(result_len) - k - 1][0]; + output_data[k + result_len] = path_helper[static_cast(result_len) - k - 1][1]; + } + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h new file mode 100644 index 0000000000000..76083d426a58a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
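// The DynamicTimeWarping kernel above takes a (rows x cols) pairwise-cost matrix and emits a
// 2 x path_length tensor holding the alignment path. A minimal standalone sketch of the same
// cost recursion and back-trace; tie-breaking may differ slightly from the kernel, and the
// helper name DtwPath is illustrative only.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

inline std::vector<std::pair<int, int>> DtwPath(const std::vector<float>& cost_matrix,
                                                std::size_t rows, std::size_t cols) {
  const float inf = std::numeric_limits<float>::infinity();
  std::vector<std::vector<float>> cost(rows + 1, std::vector<float>(cols + 1, inf));
  std::vector<std::vector<std::int8_t>> trace(rows + 1, std::vector<std::int8_t>(cols + 1, -1));
  cost[0][0] = 0.0f;
  for (std::size_t i = 1; i <= rows; ++i) {
    for (std::size_t j = 1; j <= cols; ++j) {
      float best = cost[i - 1][j - 1];  // 0: diagonal move
      std::int8_t t = 0;
      if (cost[i - 1][j] < best) { best = cost[i - 1][j]; t = 1; }  // 1: step in rows
      if (cost[i][j - 1] < best) { best = cost[i][j - 1]; t = 2; }  // 2: step in cols
      cost[i][j] = best + cost_matrix[(i - 1) * cols + (j - 1)];
      trace[i][j] = t;
    }
  }
  std::vector<std::pair<int, int>> path;
  std::size_t i = rows, j = cols;
  while (i > 0 && j > 0) {
    path.emplace_back(static_cast<int>(i) - 1, static_cast<int>(j) - 1);
    switch (trace[i][j]) {
      case 0: --i; --j; break;
      case 1: --i; break;
      default: --j; break;
    }
  }
  // Reverse so the path runs from (0, 0) toward (rows-1, cols-1), as the kernel's output does.
  return std::vector<std::pair<int, int>>(path.rbegin(), path.rend());
}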
+ +#pragma once +#include "core/framework/op_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; + +class DynamicTimeWarping : public OpKernel { + public: + DynamicTimeWarping(const OpKernelInfo& info) : OpKernel(info) {} + + ~DynamicTimeWarping() = default; + + Status Compute(OpKernelContext* context) const override; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/tensor/unfold.cc b/onnxruntime/contrib_ops/cpu/tensor/unfold.cc new file mode 100644 index 0000000000000..edafa538be219 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/tensor/unfold.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/tensor/unfold.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + UnfoldTensor, + kMSDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + UnfoldTensor); + +template +Status LaunchUnfoldTensor(const T* input, + T* output, + int64_t leading_dims_size, + int64_t unfold_dim_size, + int64_t tailing_dims_size, + int64_t unfold_size, + int64_t step_size, + concurrency::ThreadPool* tp) { + int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1; + int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size; + + int64_t stride_leading_dst = unfold_size * tailing_dims_size * unfold_dim_size_dst; + int64_t stride_fold_dim_src = tailing_dims_size * step_size; + int64_t stride_leading_src = tailing_dims_size * unfold_dim_size; + + static constexpr double cost = 1.0; + concurrency::ThreadPool::TryParallelFor(tp, static_cast(N), cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int64_t idx = static_cast(i); + const int64_t idx_leading = idx / stride_leading_dst; + int64_t n = idx % stride_leading_dst; + const int64_t stride_fold_dim_dst = tailing_dims_size * unfold_size; + const int64_t idx_fold = n / stride_fold_dim_dst; + n %= stride_fold_dim_dst; + const int64_t idx_tailing = n / unfold_size; + const int64_t idx_append = n % unfold_size; + + int64_t idx_src = idx_leading * stride_leading_src + + idx_fold * stride_fold_dim_src + idx_tailing + + idx_append * tailing_dims_size; + output[idx] = input[idx_src]; + } + }); + + return Status::OK(); +} + +Status UnfoldTensor::Compute(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + + int dim = SafeInt(HandleNegativeAxis(dim_, rank)); + ORT_ENFORCE(dim < rank, "input rank:", rank, " is not bigger than attribut specified dim: ", dim); + ORT_ENFORCE(input_dims[dim] >= size_, "dimsize:", input_dims[dim], " is less than unfold size:", size_); + + int64_t leading_dims = std::accumulate(input_dims.begin(), input_dims.begin() + static_cast(dim), + 1LL, std::multiplies()); + int64_t tailing_dims = std::accumulate(input_dims.begin() + (static_cast(dim) + 1), + input_dims.end(), 1LL, std::multiplies()); + + std::vector output_dims(static_cast(rank) + 1, 0); + std::copy(input_dims.begin(), input_dims.end(), output_dims.begin()); + 
output_dims[dim] = (input_dims[dim] - size_) / step_ + 1; + output_dims.back() = size_; + TensorShape output_shape(output_dims); + Tensor* output_tensor = ctx->Output(0, output_shape); + + auto* tp = ctx->GetOperatorThreadPool(); + + Status status; + if (input_tensor.IsDataType()) { + status = LaunchUnfoldTensor(input_tensor.Data(), output_tensor->MutableData(), + leading_dims, input_dims[dim], tailing_dims, size_, step_, tp); + } else if (input_tensor.IsDataType()) { + status = LaunchUnfoldTensor(input_tensor.Data(), output_tensor->MutableData(), + leading_dims, input_dims[dim], tailing_dims, size_, step_, tp); + } else if (input_tensor.IsDataType()) { + status = LaunchUnfoldTensor(input_tensor.Data(), output_tensor->MutableData(), + leading_dims, input_dims[dim], tailing_dims, size_, step_, tp); + } else if (input_tensor.IsDataType()) { + status = LaunchUnfoldTensor(input_tensor.Data(), output_tensor->MutableData(), + leading_dims, input_dims[dim], tailing_dims, size_, step_, tp); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported data type: ", input_tensor.DataType()); + } + + return status; +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/tensor/unfold.h b/onnxruntime/contrib_ops/cpu/tensor/unfold.h new file mode 100644 index 0000000000000..6c48d0f67fcc2 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/tensor/unfold.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/framework/op_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; + +template +Status LaunchUnfoldTensor( + const T* input, + T* output, + int64_t leading_dims_size, + int64_t unfold_dim_size, + int64_t tailing_dims_size, + int64_t unfold_size, + int64_t step_size); + +class UnfoldTensor final : public OpKernel { + public: + UnfoldTensor(const OpKernelInfo& info) : OpKernel(info) { + dim_ = SafeInt(info.GetAttrOrDefault("dim", -1LL)); + step_ = SafeInt(info.GetAttrOrDefault("step", 1LL)); + ORT_ENFORCE(step_ > 0, "step must greater than zero!"); + + int64_t temp_size; + ORT_ENFORCE(info.GetAttr("size", &temp_size).IsOK()); + size_ = SafeInt(temp_size); + } + + ~UnfoldTensor() = default; + + Status Compute(OpKernelContext* context) const override; + + private: + int dim_; + int size_; + int step_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 8f5cdc97f27e5..b67d003eaceeb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -258,7 +258,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches current_length, cpu_state.sequences, parameters->max_length, - decoder_subgraph_.has_decoder_masked_attention_)); + decoder_subgraph_.has_decoder_masked_attention_, + this->cuda_device_prop_ != nullptr)); if (decoder_subgraph_.past_present_share_buffer_) { decoder_fetches.reserve(static_cast(decoder_subgraph_.GetFirstPresentOutputIndex()) + @@ -302,17 +303,24 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches auto cur_len = std::to_string(current_length); dumper->Print("***CurrentLength", cur_len, true); - for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) { 
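// The UnfoldTensor kernel earlier in this diff slides a window of length `size` with stride
// `step` along one axis and appends the window as a new innermost axis. A reference sketch of
// the index mapping it parallelizes, for a tensor flattened to (leading, unfold_dim, tailing);
// UnfoldRef is an illustrative name, not part of the kernel.
#include <cstdint>
#include <vector>

inline std::vector<float> UnfoldRef(const std::vector<float>& input,
                                    std::int64_t leading, std::int64_t unfold_dim,
                                    std::int64_t tailing, std::int64_t size, std::int64_t step) {
  const std::int64_t out_dim = (unfold_dim - size) / step + 1;
  std::vector<float> output(static_cast<std::size_t>(leading * out_dim * tailing * size));
  for (std::int64_t l = 0; l < leading; ++l) {
    for (std::int64_t f = 0; f < out_dim; ++f) {
      for (std::int64_t t = 0; t < tailing; ++t) {
        for (std::int64_t a = 0; a < size; ++a) {
          // Destination layout: [leading][out_dim][tailing][size]; source window element
          // comes from position f * step + a along the unfolded axis.
          const std::int64_t dst = ((l * out_dim + f) * tailing + t) * size + a;
          const std::int64_t src = (l * unfold_dim + (f * step + a)) * tailing + t;
          output[static_cast<std::size_t>(dst)] = input[static_cast<std::size_t>(src)];
        }
      }
    }
  }
  return output;
}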
+ for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) { dumper->Print("decoder_feeds", i, true); dumper->Print("", decoder_feeds[i]); } - auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers; - dumper->Print("past_sequence_length", offset, true); - dumper->Print("", decoder_feeds[offset]); - dumper->Print("beam_width", offset + 1, true); - dumper->Print("", decoder_feeds[offset + 1]); - dumper->Print("cache_redir", offset + 2, true); - dumper->Print("", decoder_feeds[offset + 2]); + for (int i = 0; i < decoder_subgraph_.num_layers; i++) { + int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i; + int self_value_idx = self_key_idx + 1; + dumper->Print("past_key_self", i, true); + dumper->Print("", decoder_feeds[self_key_idx]); + dumper->Print("past_value_self", i + 1, true); + dumper->Print("", decoder_feeds[self_value_idx]); + int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i; + int cross_value_idx = cross_key_idx + 1; + dumper->Print("past_key_cross", i, true); + dumper->Print("", decoder_feeds[cross_key_idx]); + dumper->Print("past_value_cross", i, true); + dumper->Print("", decoder_feeds[cross_value_idx]); + } #endif #ifdef DEBUG_NODE_INPUTS_OUTPUTS diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 30bf3aa0a1212..8145fbd4a4123 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -100,6 +100,7 @@ struct ISequences { virtual gsl::span GetCurrentDeviceSequences() const = 0; // Get all current beam_index sequences in one continuous block (to pass to CUDA) virtual gsl::span GetNextDeviceSequences() = 0; // Get all next beam_index sequences in one continuous block (to pass to CUDA) virtual int GetSequenceLength() const = 0; + virtual int GetMaxLength() const = 0; }; struct ILogitsProcessorList { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 723c271897a78..ecad146da6777 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const { return current_length_; } +int Sequences::GetMaxLength() const { + return max_length_; +} + #ifdef DEBUG_GENERATION void Sequences::PrintSequences(const IConsoleDumper* dumper) const { for (int i = 0; i < batch_beam_size_; i++) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 440a07e14a6cc..7dd1f28d270c7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -25,6 +25,9 @@ class Sequences : public ISequences { // Returns current sequence length. int GetSequenceLength() const override; + // Returns max sequence length. 
+ int GetMaxLength() const override; + #ifdef DEBUG_GENERATION // Print the sequences to StdOut in debug mode void PrintSequences(const IConsoleDumper* dumper) const; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index d675ba742e03b..7757435990a65 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -31,6 +31,7 @@ Subgraph::Subgraph( allocator_(nullptr), is_output_float16_(false) { num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + used_implicit_inputs = std::vector(num_implicit_inputs, true); auto& subgraph_inputs = subgraph.GetInputs(); auto& subgraph_outputs = subgraph.GetOutputs(); @@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state, // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); - for (auto& entry : node.ImplicitInputDefs()) { - feed_names.push_back(entry->Name()); + const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap(); + + const auto& implicit_input_defs = node.ImplicitInputDefs(); + for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) { + const auto* entry = implicit_input_defs[i]; + int idx; + if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) { + feed_names.push_back(entry->Name()); + } else { + --num_implicit_inputs; + used_implicit_inputs[i] = false; + } } InlinedVector feed_locations; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index bde591626bb83..8ec9c9cbdc20f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -31,6 +31,7 @@ class Subgraph { const GraphViewer& subgraph; // The subgraph int num_implicit_inputs; + std::vector used_implicit_inputs; int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience. int num_subgraph_outputs; // Same as subgraph_output_names.size() diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 9037e58aaf31f..f4e7173c917c1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len, - bool need_cache_indir) { + bool need_cache_indir, + bool use_cuda) { ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); // Allocate subgraph inputs from same device as inputs of encoder subgraph. 
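// The hunk that follows adds a CUDA-aware path for seeding the decoder's input_ids from the
// beam-search sequences. A simplified sketch of the two copy strategies, with the actual
// device copy left abstract (the kernel issues it through device_copy_int32_func):
#include <cstddef>
#include <cstring>
#include <vector>

// Host path: gather each beam's tokens into a staging buffer, then issue one hostToDevice
// copy of the whole buffer into input_ids.
inline void GatherHostSequences(const std::vector<std::vector<int>>& sequences,  // one row per beam
                                int cur_len, std::vector<int>& staging) {
  staging.resize(sequences.size() * static_cast<std::size_t>(cur_len));
  for (std::size_t i = 0; i < sequences.size(); ++i) {
    std::memcpy(staging.data() + i * static_cast<std::size_t>(cur_len),
                sequences[i].data(), static_cast<std::size_t>(cur_len) * sizeof(int));
  }
}

// Device path (use_cuda == true): the sequences already live on the GPU, one beam per row with
// stride GetMaxLength(), so each beam's first GetSequenceLength() tokens are copied
// deviceToDevice into input_ids without staging through the host.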
@@ -171,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); + size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); + size_t total_size_bytes = total_size * sizeof(int); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { @@ -182,19 +184,35 @@ Status T5DecoderSubgraph::CreateInitialFeeds( stream, DeviceCopyDirection::hostToDevice)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + if (use_cuda) { + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + for (int i = 0; i < batch_beam_size; i++) { + size_t batch_beam_stride = static_cast(i) * static_cast(sequences.GetMaxLength()); + int seq_size = sequences.GetSequenceLength(); + gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); + gsl::span temp_input(input_ids_data + static_cast(i) * seq_size, seq_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + sequence, + stream, + DeviceCopyDirection::deviceToDevice)); + } + } else { + const size_t cur_len_bytes = cur_len * sizeof(int); + for (int i = 0; i < batch_beam_size; i++) { + gsl::span sequence = sequences.GetSequence(i); + const int32_t* sequence_data = sequence.data(); + ptrdiff_t seq_index = static_cast(i) * cur_len; + memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes); + } + gsl::span temp_input(input_ids_data, total_size); + gsl::span temp_sequence(seq_copy_ptr, total_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + temp_sequence, + stream, + DeviceCopyDirection::hostToDevice)); } - gsl::span temp_input(input_ids_data, total_size); - gsl::span temp_sequence(seq_copy_ptr, total_size); - ORT_RETURN_IF_ERROR(device_copy_int32_func( - temp_input, - temp_sequence, - stream, - DeviceCopyDirection::hostToDevice)); } // The ordering is the same as used in Setup. @@ -230,7 +248,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } else { ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, @@ -238,7 +256,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } decoder_feeds.push_back(expanded_hidden_states); @@ -281,8 +299,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds( } // Pass through implicit inputs. 
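// The change that follows (and the matching encoder change) only forwards implicit inputs the
// subgraph actually consumes, using the used_implicit_inputs mask computed in Subgraph::Setup
// above. A minimal sketch of the filtering pattern, with the value type left generic:
#include <vector>

template <typename OrtValueT>
void AppendUsedImplicitInputs(const std::vector<const OrtValueT*>& implicit_inputs,
                              const std::vector<bool>& used_implicit_inputs,
                              std::vector<OrtValueT>& feeds) {
  for (std::size_t i = 0; i < implicit_inputs.size(); ++i) {
    if (used_implicit_inputs[i]) {
      feeds.push_back(*implicit_inputs[i]);
    }
  }
}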
- for (const auto* entry : implicit_inputs) { - decoder_feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + decoder_feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 83dae49c7dcbd..a72ce37a93aba 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph { int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len = -1, - bool need_cache_indir = false); + bool need_cache_indir = false, + bool use_cuda = false); Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 51473c0c931b9..d59db4afac2c2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds( pinned_allocator, location)); - for (const auto* entry : implicit_inputs) { - feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index efbc0b5031657..22e2879a5be15 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -102,6 +102,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const int sm = device_prop.major * 10 + device_prop.minor; const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + typedef typename ToCudaType::MappedType CudaT; + AttentionData data; + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && (nullptr == attention_bias) && @@ -118,21 +121,26 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention = false; } // Allocate buffers + size_t softmax_lse_bytes = 0; size_t softmax_lse_accum_bytes = 0; size_t out_accum_bytes = 0; if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, parameters.num_heads); + using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = static_cast(num_splits); + data.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, 
context->GetComputeStream()); auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif @@ -247,6 +255,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; constexpr bool use_cudnn_flash_attention = false; + constexpr bool use_lean_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -257,14 +266,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_flash_attention, false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); - typedef typename ToCudaType::MappedType CudaT; - AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); if (nullptr != bias) { data.bias = reinterpret_cast(bias->Data()); @@ -289,6 +297,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (softmax_lse_accum_buffer != nullptr) { data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index eff58c0080012..9e017544d7cff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -39,6 +39,7 @@ limitations under the License. 
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" #include "contrib_ops/cuda/bert/attention_impl.h" using namespace onnxruntime::cuda; @@ -108,6 +109,7 @@ size_t GetAttentionWorkspaceSize( size_t total_sequence_length, void* fused_runner, bool use_flash_attention, + bool use_lean_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, bool use_cudnn_flash_attention, @@ -119,12 +121,20 @@ size_t GetAttentionWorkspaceSize( #if USE_FLASH_ATTENTION if (use_flash_attention) { - return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + return qkv_bytes; } #else ORT_UNUSED_PARAMETER(use_flash_attention); #endif +#if USE_LEAN_ATTENTION + if (use_lean_attention) { + return qkv_bytes; + } +#else + ORT_UNUSED_PARAMETER(use_lean_attention); +#endif + #if USE_MEMORY_EFFICIENT_ATTENTION if (use_memory_efficient_attention) { size_t fmha_buffer_bytes = 0; @@ -301,10 +311,10 @@ Status FlashAttention( constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), + device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.softmax_lse), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16, - false, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + false, data.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); @@ -326,6 +336,81 @@ Status FlashAttention( } #endif +#if USE_LEAN_ATTENTION +template +Status LeanAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + assert(nullptr == data.mask_index); + assert(nullptr == data.attention_bias); + assert(parameters.head_size == parameters.v_head_size); + + constexpr bool is_bf16 = false; + + ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache( + device_prop, stream, + data.q, + data.k, // k_cache + data.v, // v_cache + nullptr, // new_k (we have appended new_k to k_cache) + nullptr, // new_v (we have appended new_v to k_cache) + data.output, + reinterpret_cast(data.softmax_lse), + nullptr, // seqlens_k + nullptr, // cos_cache + nullptr, // sin_cache + nullptr, // block_table + parameters.batch_size, + parameters.num_heads, + parameters.num_heads, // num_heads_k + parameters.head_size, + parameters.sequence_length, // seqlen_q + parameters.total_sequence_length, // seqlen_k + 0, // seqlen_k_new + 0, // rotary_dim + scale, // softmax_scale + parameters.is_unidirectional, + is_bf16, + false, // past_bsnh + data.num_splits, + data.grid_dim_z, + data.max_tiles_per_tb, + data.high_load_tbs, + data.tiles_per_head, + reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), + data.lean_sync_flag, + -1, // local_window_size + false, // is_rotary_interleaved + false // is_packed_qkv + )); + + return Status::OK(); +} + +template <> +Status LeanAttention( 
+ const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + ORT_UNUSED_PARAMETER(device_prop); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(parameters); + ORT_UNUSED_PARAMETER(data); + ORT_UNUSED_PARAMETER(scale); + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "lean attention does not support float tensor"); +} +#endif + + + template Status CudnnFlashAttention( cudnnHandle_t cudnn_handle, @@ -641,6 +726,11 @@ Status QkvToContext( // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; +#if USE_LEAN_ATTENTION + if (data.use_lean_attention) { + return LeanAttention(device_prop, stream, parameters, data, scale); + } +#endif #if USE_FLASH_ATTENTION if (data.use_flash_attention) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index fcc9af9681223..7d111a1ee21bf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -53,6 +53,7 @@ size_t GetAttentionWorkspaceSize( size_t total_sequence_length, void* fused_runner, bool use_flash_attention, + bool use_lean_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, bool use_cudnn_flash_attention, @@ -102,6 +103,19 @@ struct AttentionData { T* softmax_lse_accum = nullptr; T* out_accum = nullptr; + // Flash Atttention and Lean Attention + int num_splits; + + // Lean Attention + bool use_lean_attention = false; +#if USE_LEAN_ATTENTION + int grid_dim_z = 0; + int max_tiles_per_tb = 0; + int high_load_tbs = 0; + int tiles_per_head = 0; + int* lean_sync_flag = nullptr; +#endif + // For Debugging size_t workspace_bytes = 0; bool allow_debug_info = false; @@ -115,6 +129,7 @@ struct AttentionData { void PrintDebugInfo() const { std::cout << "flash=" << use_flash_attention + << ", lean=" << use_lean_attention << ", efficient=" << use_memory_efficient_attention << ", fused_runner=" << (fused_runner != nullptr) << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc index 7d21451df5b86..8b8b764e7c785 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -17,6 +17,9 @@ namespace onnxruntime { void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) { if (value > 0) { use_flash_attention_ = (value & static_cast(AttentionBackend::FLASH_ATTENTION)) > 0; +#if USE_LEAN_ATTENTION + use_lean_attention_ = (value & static_cast(AttentionBackend::LEAN_ATTENTION)) > 0; +#endif use_efficient_attention_ = (value & static_cast(AttentionBackend::EFFICIENT_ATTENTION)) > 0; use_trt_fused_attention_ = (value & static_cast(AttentionBackend::TRT_FUSED_ATTENTION)) > 0; use_cudnn_flash_attention_ = (value & static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0; @@ -26,6 +29,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che use_trt_causal_attention_ = (value & static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0; } else { use_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFlashAttention, false); +#if USE_LEAN_ATTENTION + 
use_lean_attention_ = ParseEnvironmentVariableWithDefault(kEnableLeanAttention, false); +#endif use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); @@ -61,6 +67,10 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che use_flash_attention_ = false; #endif +#ifndef USE_LEAN_ATTENTION + use_lean_attention_ = false; +#endif + #ifndef USE_MEMORY_EFFICIENT_ATTENTION use_efficient_attention_ = false; #endif @@ -81,6 +91,9 @@ void AttentionKernelOptions::Print() const { std::stringstream sstream; sstream << "AttentionKernelOptions:"; sstream << " FLASH_ATTENTION=" << int(use_flash_attention_); +#if USE_LEAN_ATTENTION + sstream << " LEAN_ATTENTION=" << int(use_lean_attention_); +#endif sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_); sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_); sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_); @@ -131,6 +144,10 @@ void AttentionKernelDebugInfo::Print(const char* operator_name, sstream << " SdpaKernel="; if (use_flash_attention.has_value() && use_flash_attention.value()) { sstream << "FLASH_ATTENTION"; +#if USE_LEAN_ATTENTION + } else if (use_lean_attention.has_value() && use_lean_attention.value()) { + sstream << "LEAN_ATTENTION"; +#endif } else if (use_efficient_attention.has_value() && use_efficient_attention.value()) { sstream << "EFFICIENT_ATTENTION"; } else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index a27fb199a6272..caed704564c3b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -9,6 +9,7 @@ namespace onnxruntime { struct AttentionKernelDebugInfo { std::optional use_flash_attention = std::nullopt; + std::optional use_lean_attention = std::nullopt; std::optional use_efficient_attention = std::nullopt; std::optional use_trt_fused_attention = std::nullopt; std::optional use_cudnn_flash_attention = std::nullopt; @@ -24,6 +25,7 @@ class AttentionKernelOptions { void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false); bool UseFlashAttention() const { return use_flash_attention_; } + bool UseLeanAttention() const { return use_lean_attention_; } bool UseEfficientAttention() const { return use_efficient_attention_; } bool UseTrtFusedAttention() const { return use_trt_fused_attention_; } bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; } @@ -44,6 +46,7 @@ class AttentionKernelOptions { private: bool use_flash_attention_{true}; + bool use_lean_attention_{false}; bool use_efficient_attention_{true}; bool use_trt_fused_attention_{true}; bool use_cudnn_flash_attention_{false}; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a079076f2881b..c8c0191967d40 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -384,6 +384,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, if 
(data.use_memory_efficient_attention || data.use_flash_attention || + data.use_lean_attention || data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Use oiginal Query (BSNH) since there is no bias. data.q = const_cast(data.query); diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 8edae863ff44e..e4c1659c0fb2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -298,6 +298,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio if (params.attention_bias != nullptr) { qk = add_vec(qk, reinterpret_cast(params.attention_bias)[attn_bias_offset + tlength]); } + if (params.mask != nullptr && params.mask[bi_total_seq_length + params.past_sequence_length] == 0) { + qk += params.mask_filter_value; + } qk_max = qk; qk_smem[tlength] = qk; } @@ -534,7 +537,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio if (params.out_qk != nullptr) { // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length] - float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength); + float* target = (reinterpret_cast(params.out_qk)) + (static_cast(bhi) * (sum_tlength + 1)); for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { target[ti] = (float)(qk_smem[ti]); } diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index efad33855328f..0e1c9ce7b108e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -10,45 +10,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { - int beam_width = 1; - - // Only NeoX style rotary embedding is supported - int rotary_embedding_dim = 0; - int t_step = 0; - - // Weather to use multihead attention(excludes matmul and bias) - bool is_mha = false; - bool is_cross_attention = false; - bool is_packed_qkv = false; - - // Useful to better use global memory bandwidth on certain CUDA architectures. - // Turned off by default for now until we fully understand performance implications - // for all types of workloads. - // Can be turned on by appropriate environment variable (see attention_common.h). - bool kv_data_in_flight = false; - - void* q = nullptr; - void* q_bias = nullptr; - - void* k = nullptr; - void* k_bias = nullptr; - - void* v = nullptr; - void* v_bias = nullptr; - - void* attention_bias = nullptr; - - void* k_cache = nullptr; - void* v_cache = nullptr; - - void* out = nullptr; - void* out_qk = nullptr; - - const int32_t* cache_indir = nullptr; - const int32_t* mask = nullptr; // [B, total_sequence_length] -}; - template < // The type of the inputs. Supported types: float and half. 
typename T, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index e961bab399326..d46d9597a758f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -98,7 +98,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi for (int m = 0; m < size<1>(tOgO); ++m) { const int row = get<0>(tOcO(0, m, 0)); if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSE(row) = INFINITY; + gLSE(row) = std::numeric_limits::infinity(); } } return; @@ -499,7 +499,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons for (int m = 0; m < size<1>(tOgOaccum); ++m) { const int row = get<0>(tOcO(0, m, 0)); if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSEaccum(row) = Split ? -INFINITY : INFINITY; + gLSEaccum(row) = Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); } } return; @@ -1061,7 +1061,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -std::numeric_limits::infinity(); if (row < kMaxSplits) { sLSE[row][col] = lse; } @@ -1082,7 +1082,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -std::numeric_limits::infinity(); // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } } @@ -1094,7 +1094,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { } MaxOp max_op; lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf float lse_sum = expf(lse_accum(0) - lse_max); #pragma unroll for (int l = 1; l < kNLsePerThread; ++l) { @@ -1104,7 +1104,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { lse_sum = Allreduce::run(lse_sum, sum_op); // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? 
std::numeric_limits::infinity() : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h index 0998155eba635..71434002f8df1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h @@ -4,6 +4,7 @@ #pragma once +#include #include namespace onnxruntime { @@ -28,7 +29,7 @@ __forceinline__ __device__ void apply_mask(Tensor& tensor, const // Without the "make_coord" we get wrong results #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -59,7 +60,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor& tensor, for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -96,7 +97,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx( #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; + tensor(mi, ni) = -std::numeric_limits::infinity(); } } // if (cute::thread0()) { @@ -151,7 +152,7 @@ struct Mask { } if constexpr (!Is_even_MN) { if (col_idx >= max_seqlen_k) { - tensor(mi, make_coord(j, nj)) = -INFINITY; + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -181,18 +182,18 @@ struct Mask { } if constexpr (Causal_mask) { if (col_idx >= col_idx_limit_right) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } if constexpr (Is_local) { if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { // Causal and Local already handles MN masking if (col_idx >= max_seqlen_k) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 7e0095cb39bd9..7fe506e01a9b9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include @@ -71,7 +72,9 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? 
scale : float(M_LOG2E)); + const float max_scaled = max(mi) == -std::numeric_limits::infinity() + ? 0.f + : max(mi) * (Scale_max ? scale : float(M_LOG2E)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - @@ -99,7 +102,7 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor& ten max(mi) = Allreduce<4>::run(max(mi), max_op); // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + const float max_scaled = max(mi) == -std::numeric_limits::infinity() ? 0.f : max(mi) * scale; sum(mi) = 0; #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { @@ -143,7 +146,7 @@ struct Softmax { for (int mi = 0; mi < size(row_max); ++mi) { float scores_max_cur = !Check_inf ? row_max(mi) - : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + : (row_max(mi) == -std::numeric_limits::infinity() ? 0.0f : row_max(mi)); float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); row_sum(mi) *= scores_scale; #pragma unroll @@ -169,7 +172,9 @@ struct Softmax { for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + lse(mi) = (sum == 0.f || sum != sum) + ? (Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity()) + : row_max(mi) * softmax_scale + __logf(sum); float scale = inv_sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h new file mode 100644 index 0000000000000..6d9ed824b4b76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h @@ -0,0 +1,45 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace onnxruntime { +namespace lean { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 
0 : params.seqlen_knew)) { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h new file mode 100644 index 0000000000000..a2058d8805ebd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h @@ -0,0 +1,148 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +namespace onnxruntime { +namespace lean { + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr; + void* __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void* __restrict__ p_ptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr; + void* __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q; + int* __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int* __restrict__ seqused_k; + + int* __restrict__ blockmask; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr; + void* __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. 
+ index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr; + void* __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx; + + // Paged KV cache + int* __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t* rng_state; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version and lean + + void* __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + // LEAN Additional Params + int lean_griddimz; + int tiles_per_head; + int max_tiles_per_tb; + int high_load_tbs; + void* __restrict__ sync_flag; + + const cudaDeviceProp* dprops = nullptr; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h new file mode 100644 index 0000000000000..85be5d3e031ac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h @@ -0,0 +1,315 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutAtomO = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(ElementAccum); + // static constexpr int kSmemSize = Share_Q_K_smem ? 
std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize + kSmemOSize; + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + + using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutKtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposed = decltype(composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutQdOtransposed = decltype(composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + + using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc new file mode 100644 index 0000000000000..81301ebc7ba64 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc @@ -0,0 +1,453 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Modifications: support lean attention. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
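// ---------------------------------------------------------------------------
// Illustrative host-side sketch (not part of this file): the lean-attention
// launch is a persistent, StreamK-style grid, and the scheduling fields it
// needs (lean_griddimz, tiles_per_head, max_tiles_per_tb, high_load_tbs) come
// out of get_num_splits_and_buffer_sizes() further down in this file. The
// sketch below reproduces only the tile accounting of that helper so the
// split of work across CTAs is easy to follow. All names in the sketch
// (LeanSchedule, plan_lean_schedule) are invented for illustration; it assumes
// a causal prefill, skips the GQA remapping of seqlen_q, and omits num_splits
// and the scratch-buffer byte sizes that the real helper also returns.
#include <algorithm>
#include <cstddef>

struct LeanSchedule {
  size_t tiles_per_head;    // (M,N) tiles owned by one (batch, kv-head) pair
  size_t total_tiles;       // tiles across the whole launch
  size_t grid_dim_z;        // number of persistent CTAs, roughly 2 per SM
  size_t max_tiles_per_tb;  // heaviest per-CTA load, ceil(total_tiles / grid_dim_z)
  size_t high_load_tbs;     // CTAs that carry max_tiles_per_tb tiles; the rest carry one fewer
};

inline LeanSchedule plan_lean_schedule(size_t seqlen_q, size_t seqlen_k, size_t batch_size,
                                       size_t num_heads_k, size_t num_sms,
                                       size_t block_m, size_t block_n, bool is_causal) {
  const size_t num_m_blocks = (seqlen_q + block_m - 1) / block_m;
  const size_t num_n_blocks = (seqlen_k + block_n - 1) / block_n;

  size_t tiles_per_head = 0;
  if (is_causal) {
    // Query row-block i only attends to keys up to position (i + 1) * block_m.
    for (size_t i = 0; i < num_m_blocks; ++i) {
      tiles_per_head += ((i + 1) * block_m + block_n - 1) / block_n;
    }
  } else {
    tiles_per_head = num_m_blocks * num_n_blocks;
  }
  const size_t total_tiles = tiles_per_head * batch_size * num_heads_k;

  // Persistent grid: about two CTAs per SM, shrunk when there is little work
  // so that each CTA still sees at least a couple of tiles.
  size_t grid_dim_z = num_sms * 2;
  if (total_tiles <= 2 * 2 * num_sms) {
    grid_dim_z = std::min((total_tiles + 1) / 2,
                          (32 * total_tiles + num_n_blocks - 1) / num_n_blocks);
  } else {
    grid_dim_z = std::min(2 * num_sms, 32 * num_heads_k * batch_size * num_m_blocks);
  }

  const size_t max_tiles_per_tb = (total_tiles + grid_dim_z - 1) / grid_dim_z;
  // Work is split at tile granularity, so no two CTAs ever differ by more than
  // one tile: high_load_tbs CTAs take max_tiles_per_tb tiles, the rest one fewer.
  const size_t high_load_tbs = total_tiles - (max_tiles_per_tb - 1) * grid_dim_z;

  return {tiles_per_head, total_tiles, grid_dim_z, max_tiles_per_tb, high_load_tbs};
}
// ---------------------------------------------------------------------------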
+ +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" +#include + +#include "contrib_ops/cuda/bert/lean_attention/flash.h" +#include "contrib_ops/cuda/bert/lean_attention/static_switch.h" + +namespace onnxruntime { +namespace lean { + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* seqused_k, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal, + bool is_bf16, + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + params.is_bf16 = is_bf16; + + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4267) // Ignore conversion from 'size_t' to 'int', possible loss of data +#pragma warning(disable : 4244) // Ignore conversion from 'double' to 'float', possible loss of data +#endif + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. 
+ params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates + // local and causal, meaning when we have local window size + params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; +} + +size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { + size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; + return bytes; +} + +size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, + size_t seqlen_q, size_t head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +size_t get_sync_flag_size(size_t num_m_blocks, size_t batch_size, size_t num_heads) { + size_t bytes = sizeof(int) * batch_size * num_heads * num_m_blocks; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_lean_dispatch(params, stream); + }); + }); +} + +std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t max_seqlen_q, size_t max_seqlen_k, + size_t num_heads, size_t num_heads_k, size_t head_size, size_t num_SMs, bool is_causal) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int block_m = head_size <= 64 ? 64 : (head_size <= 128 ? 
64 : 64); + const int num_m_blocks = (max_seqlen_q + block_m - 1) / block_m; + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + if (max_seqlen_q == 1) { + is_causal = false; + } + + max_seqlen_q = max_seqlen_q * num_heads / num_heads_k; + +#if defined(DEBUG_LEAN_ATTENTION) + printf("block_n: %d\n", block_n); + printf("block_m: %d\n", block_m); + printf("num_m_blocks: %d\n", num_m_blocks); + printf("num_n_blocks: %d\n", num_n_blocks); + printf("max_seqlen_q: %lu\n", max_seqlen_q); + printf("max_seqlen_k: %lu\n", max_seqlen_k); + printf("is_causal: %d\n", is_causal); + printf("num_heads: %lu\n", num_heads); + printf("num_heads_k: %lu\n", num_heads_k); +#endif + + size_t tiles_per_head = 0; + if (is_causal) { + // Prefill - Causal + for (int i = 0; i < num_m_blocks; i++) { + tiles_per_head += (((i + 1) * block_m) + block_n - 1) / block_n; + } + } else { + // Decode or Not Causal + // Tiles per head is the number of blocks in the first block + tiles_per_head = num_m_blocks * num_n_blocks; + } + size_t total_tiles = tiles_per_head * batch_size * num_heads_k; + + // StreamK Lean has as many threadblocks as SMs + // This should be a function of tile size and number of scratchpad space + + // We want at least two tiles per CTA to be efficient + // And then 2 CTAs per SM + size_t lean_griddimz = num_SMs * 2; + if (total_tiles <= 2 * 2 * num_SMs) { + lean_griddimz = std::min((total_tiles + 1) / 2, (32 * total_tiles + num_n_blocks - 1) / num_n_blocks); + // params.lean_griddimz = num_m_blocks * batch_size * num_heads; + } else { + // Max split of 64 per block is allowed, so we conservatively set it to 32 + // to account for ceil + lean_griddimz = std::min(2 * num_SMs, 32 * num_heads_k * batch_size * num_m_blocks); + } + size_t max_tiles_per_tb = (total_tiles + lean_griddimz - 1) / lean_griddimz; + // Find max number of splits + size_t num_splits = 0; + if (total_tiles % lean_griddimz == 0) { + num_splits = 1 + ((num_n_blocks + max_tiles_per_tb - 2) / (max_tiles_per_tb)); + } else { + num_splits = 1 + ((num_n_blocks + max_tiles_per_tb - 3) / (max_tiles_per_tb - 1)); + } + size_t high_load_tbs = total_tiles - ((max_tiles_per_tb - 1) * lean_griddimz); + +#if defined(DEBUG_LEAN_ATTENTION) + printf("Causal: %d params.tiles_per_head : %lu\n", is_causal, tiles_per_head); + printf("num_splits = %lu\n", num_splits); + printf("total_tiles = %lu\n", total_tiles); + printf("lean_griddimz = %lu\n", lean_griddimz); + printf("max_tiles_per_tb = %lu\n", max_tiles_per_tb); + printf("high_load_tbs = %lu\n", high_load_tbs); +#endif + + if (num_splits > 1) { + size_t softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads_k, max_seqlen_q); + auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; + const size_t head_size_rounded = round_multiple(head_size, 32); + size_t out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads_k, max_seqlen_q, head_size_rounded); + size_t sync_flag_bytes = get_sync_flag_size(num_m_blocks, batch_size, num_heads_k); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes, sync_flag_bytes, lean_griddimz, max_tiles_per_tb, high_load_tbs, tiles_per_head}; + } else { + return {0, 0, 0, 0, lean_griddimz, max_tiles_per_tb, high_load_tbs, tiles_per_head}; + } +} + +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || 
is_sm90) && (head_size == 64 || head_size == 128) && (num_heads % num_heads_k == 0); +} + +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + int rotary_dim, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits, + int grid_dimz, + int max_tiles_per_tb, + int high_load_tbs, + int tiles_per_head, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int* sync_flag, + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv, + int max_num_blocks_per_seq, + int page_block_size) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + const bool paged_KV = block_table != nullptr; + +#if defined(DEBUG_LEAN_ATTENTION) + printf( + "batch_size: %d num_heads %d num_heads_k %d head_size %d seqlen_q %d seqlen_k %d seqlen_k_new %d " + "softmax_scale %f is_causal %d is_bf16 %d past_bsnh %d num_splits %d grid_dimz %d max_tiles_per_tb %d " + "high_load_tbs %d tiles_per_head %d local_window_size %d is_rotary_interleaved %d is_packed_qkv %d " + "max_num_blocks_per_seq %d page_block_size %d\n", + batch_size, num_heads, num_heads_k, head_size, seqlen_q, seqlen_k, seqlen_k_new, + softmax_scale, is_causal, is_bf16, past_bsnh, num_splits, grid_dimz, max_tiles_per_tb, + high_load_tbs, tiles_per_head, local_window_size, is_rotary_interleaved, is_packed_qkv, + max_num_blocks_per_seq, page_block_size); +#endif + + // Lean attention treats decode as non-causal + if (seqlen_q == 1) { + is_causal = false; + } + + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + // In kv-cache case, seqlen_k_max as kv sequence length + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + 
/*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + past_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + + if (k_new != nullptr && v_new != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; + // All stride are in elements, not bytes. + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + if (seqlenq_ngroups_swapped) { + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads_k * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + } else { + params.q_batch_stride = seqlen_q * num_heads_k * head_size; + } + params.q_row_stride = head_size; + params.q_head_stride = seqlen_q * head_size; + params.o_row_stride = head_size; + params.o_head_stride = seqlen_q * head_size; + params.o_batch_stride = seqlen_q * num_heads_k * head_size; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = rotary_dim; + } + + params.num_splits = num_splits; + params.lean_griddimz = grid_dimz; + params.max_tiles_per_tb = max_tiles_per_tb; + params.high_load_tbs = high_load_tbs; + params.tiles_per_head = tiles_per_head; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + params.sync_flag = sync_flag; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + params.alibi_slopes_ptr = nullptr; + if (paged_KV) { + params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table_batch_stride = max_num_blocks_per_seq; + // params.num_blocks = num_blocks; + params.page_block_size = page_block_size; + params.k_batch_stride = page_block_size * num_heads_k * head_size; + 
params.v_batch_stride = page_block_size * num_heads_k * head_size; + } else { + params.block_table = nullptr; + params.block_table_batch_stride = 0; + // params.num_blocks = 0; + params.page_block_size = 1; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream); + + return Status::OK(); +} + +} // namespace lean +} // namespace onnxruntime + +#endif // USE_LEAN_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h new file mode 100644 index 0000000000000..3b9bd1c24f08c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if USE_LEAN_ATTENTION + +#include "core/providers/cuda/cuda_common.h" +#include + +namespace onnxruntime { +namespace lean { + +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + int rotary_dim, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + int grid_dimz = 0, + int max_tiles_per_tb = 0, + int high_load_tbs = 0, + int tiles_per_head = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int* sync_flag = nullptr, + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false, + int max_num_blocks_per_seq = 0, + int page_block_size = 1); + +size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); + +std::tuple +get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, + size_t num_heads_k, size_t head_size, size_t num_SMs, bool is_causal); + +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k); + +} // namespace lean +} // namespace onnxruntime + +#endif // USE_LEAN_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu new file mode 100644 index 0000000000000..cfcacbabb3cb9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
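// ---------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this translation unit): the
// prologue of lean_compute_attn_impl_ver3 in lean_fwd_kernel.h, added later in
// this diff, has to recover from the flat per-head tile index start_tile_hid
// which query row-block a CTA should start at. In the causal case with
// kBlockM == kBlockN (block_scale == 1), row-block m owns m + 1 tiles, so the
// row is found by inverting the triangular series
// (m + 1)(m + 2) / 2 = start_tile_hid + 1, which is the commented-out
// ceilf((sqrtf(9 + 8 * start_tile_hid) - 3) / 2) expression in that prologue;
// the live code generalizes it with block_scale for kBlockM != kBlockN. The
// names below (m_block_of_tile, m_block_of_tile_ref) are invented for
// illustration; the brute-force reference only confirms the closed form.
#include <cassert>
#include <cmath>

// Closed form used by the kernel when block_scale == 1.
inline int m_block_of_tile(int start_tile_hid) {
  return static_cast<int>(std::ceil((std::sqrt(9.0 + 8.0 * start_tile_hid) - 3.0) / 2.0));
}

// Reference: walk the row-blocks, counting m + 1 tiles per row, until the
// flat index falls inside the current row.
inline int m_block_of_tile_ref(int start_tile_hid) {
  int m = 0;
  int tiles_before = 0;
  while (start_tile_hid >= tiles_before + (m + 1)) {
    tiles_before += m + 1;
    ++m;
  }
  return m;
}

int main() {
  for (int tile = 0; tile < (1 << 16); ++tile) {
    assert(m_block_of_tile(tile) == m_block_of_tile_ref(tile));
  }
  return 0;
}
// ---------------------------------------------------------------------------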
+ +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h" + +namespace onnxruntime { +namespace lean { + +template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu new file mode 100644 index 0000000000000..44c870f6ab35b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h" + +namespace onnxruntime { +namespace lean { + +template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h new file mode 100644 index 0000000000000..bd54b404420e5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h @@ -0,0 +1,1068 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "contrib_ops/cuda/bert/lean_attention/block_info.h" +#include "contrib_ops/cuda/bert/lean_attention/kernel_traits.h" +#include "contrib_ops/cuda/bert/lean_attention/utils.h" +#include "contrib_ops/cuda/bert/lean_attention/softmax.h" +#include "contrib_ops/cuda/bert/lean_attention/mask.h" + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialized for Prefill +template +inline __device__ void lean_compute_attn_impl_ver3(const Params& params, const int cta_id, int start_tile_gid, int start_tile_hid, int num_tiles, const int num_tiles_per_head) { +#if defined(DEBUG_LEAN_ATTENTION) + // Timing + auto kernel_start = clock64(); + long long int comp1_duration = 0; + long long int comp2_duration = 0; + long long int epilogue_duration = 0; + long long int prologue_duration = 0; + long long int epil1_duration = 0; + long long int epil2_duration = 0; + long long int epil3_duration = 0; + + const int tracing_block = 0; +#endif + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = typename Kernel_traits::GmemTiledCopyO; + using GmemTiledCopyOaccum = typename Kernel_traits::GmemTiledCopyOaccum; + + const int num_m_blocks_per_head = (params.seqlen_q + kBlockM - 1) / kBlockM; + + // // This is the solution to the summation series (n+1)(n+2)/2 = start_tile_hid + 1 + // int cur_m_block = Is_causal ? 
(int)ceilf((sqrtf(9 + (8*start_tile_hid)) - 3) / 2) : start_tile_hid/num_tiles_per_head; + float block_scale = (float)kBlockM / (float)kBlockN; + int cur_m_block = Is_causal ? kBlockM > kBlockN ? (int)ceilf((sqrtf(1 + (8 * start_tile_hid + 8) / block_scale) - 3) / 2) + // : (int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) * (1 / block_scale) + (int)((start_tile_hid - (1 / block_scale) * ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) * ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) + 1) / 2)) / ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) + 1)) + : static_cast((-1 + sqrt(1 + 8 * start_tile_hid * block_scale)) / (2 * block_scale)) + : start_tile_hid / num_tiles_per_head; + int num_tiles_in_block = Is_causal ? (int)ceilf(block_scale * (cur_m_block + 1)) : num_tiles_per_head; + int cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + int cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + + int num_tiles_left = num_tiles; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Debugging block = %d\n", tracing_block); + printf("kBlockM = %d\n", kBlockM); + printf("kBlockN = %d\n", kBlockN); + printf("kHeadDim = %d\n", kHeadDim); + printf("kNWarps = %d\n", kNWarps); + printf("IsEvenMN = %d\n", Is_even_MN); + printf("block_scale = %f\n", block_scale); + printf("seq_len_q -change = %d\n", params.seqlen_q); + printf("seq_len_k = %d\n", params.seqlen_k); + printf("q_batch_stride = %ld\n", params.q_batch_stride); + printf("q_head_stride = %ld\n", params.q_head_stride); + printf("q_row_stride = %ld\n", params.q_row_stride); + printf("k_batch_stride = %ld\n", params.k_batch_stride); + printf("k_head_stride = %ld\n", params.k_head_stride); + printf("k_row_stride = %ld\n", params.k_row_stride); + printf("v_row_stride = %ld\n", params.v_row_stride); + printf("o_row_stride = %ld\n", params.o_row_stride); + printf("start_m_block = %d\n", cur_m_block); + printf("start_tile_gid = %d\n", start_tile_gid); + printf("start_tile_hid = %d\n", start_tile_hid); + printf("cur_bidb = %d/%d\n", cur_bidb, params.b); + printf("cur_bidh = %d/%d\n", cur_bidh, params.h); + printf("num_m_blocks_per_head = %d\n", num_m_blocks_per_head); + printf("cur_m_block = %d\n", cur_m_block); + printf("num_tiles_in_block = %d\n", num_tiles_in_block); + printf("Total tiles = %d\n", num_tiles); + } +#endif + + // Prologue + int n_tile_min = kBlockM > kBlockN ? start_tile_hid - (block_scale * cur_m_block * (cur_m_block + 1) / 2) + : start_tile_hid - (int)(((int)floorf(cur_m_block * block_scale) * ((int)floorf(cur_m_block * block_scale) + 1) / 2) / block_scale) - ((cur_m_block % int(1 / block_scale)) * (floorf(cur_m_block * block_scale) + 1)); + int n_tile = n_tile_min + num_tiles_left - 1 >= num_tiles_in_block ? 
num_tiles_in_block - 1 : n_tile_min + num_tiles_left - 1; + + index_t row_offset_q = cur_bidb * params.q_batch_stride + + +cur_m_block * kBlockM * params.q_row_stride + cur_bidh * params.q_head_stride; + index_t row_offset_k = cur_bidb * params.k_batch_stride + + +n_tile * kBlockN * params.k_row_stride + (cur_bidh / params.h_h_k_ratio) * params.k_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + + // PREDICATES + // + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // // Start from the last block of first head + // lean::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + // params.seqlen_q - cur_m_block * kBlockM); + + // // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
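// Illustrative sketch (not part of this patch): the closed form used above to find which
// m-block owns a given tile index within a causal head, for the simple case
// kBlockM == kBlockN (block_scale == 1). Block m then owns m + 1 KV tiles, so m(m+1)/2
// tiles precede it; solving (n+1)(n+2)/2 = start_tile_hid + 1 and rounding up recovers m.
#include <cmath>

static inline int owning_m_block(int start_tile_hid) {
  return (int)std::ceil((std::sqrt(9.0 + 8.0 * start_tile_hid) - 3.0) / 2.0);
}
// tile 0     -> block 0
// tiles 1..2 -> block 1
// tiles 3..5 -> block 2
// tiles 6..9 -> block 3, and so on.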
+ // lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + // params.seqlen_k - n_tile * kBlockN); + // cute::cp_async_fence(); + + index_t row_offset_v = cur_bidb * params.v_batch_stride + + +n_tile * kBlockN * params.v_row_stride + (cur_bidh / params.h_h_k_ratio) * params.v_head_stride; + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + // Tiled Matrix Multiply + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling - Can be moved + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("n_tile_min = %d\n", n_tile_min); + printf("n_tile = %d\n", n_tile); + printf("row_offset_q = %" PRId64 "\n", row_offset_q); + printf("row_offset_k = %" PRId64 "\n", row_offset_k); + printf("row_offset_v = %" PRId64 "\n", row_offset_v); + } + + int num_blocks = 0; +#endif + + for (; num_tiles_left > 0;) { +#if defined(DEBUG_LEAN_ATTENTION) + num_blocks += 1; + auto prologue_start = clock64(); +#endif + + cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + // Scheduling Policy - below + + // Calculate split ID + int block_start_gid = start_tile_gid - n_tile_min; + int cta_id_block_start = block_start_gid > params.high_load_tbs * params.max_tiles_per_tb + ? params.high_load_tbs + ((block_start_gid - (params.high_load_tbs * params.max_tiles_per_tb)) / (params.max_tiles_per_tb - 1)) + : block_start_gid / params.max_tiles_per_tb; + int n_split_idx = cta_id - cta_id_block_start; + + // Check host/ + int host_cta = 0; + int total_splits = 1; + if (n_tile_min == 0) { + host_cta = 1; + int block_end_gid = start_tile_gid + num_tiles_in_block - 1; + int cta_id_block_end = block_end_gid > params.high_load_tbs * params.max_tiles_per_tb + ? 
params.high_load_tbs + ((block_end_gid - (params.high_load_tbs * params.max_tiles_per_tb)) / (params.max_tiles_per_tb - 1)) + : block_end_gid / params.max_tiles_per_tb; + total_splits = cta_id_block_end - cta_id + 1; + } + + int end_cta = 0; + if (n_tile == num_tiles_in_block - 1) { + end_cta = 1; + } + + start_tile_gid += n_tile - n_tile_min + 1; + start_tile_hid += n_tile - n_tile_min + 1; + if (start_tile_hid >= num_tiles_per_head) { + // Next head + start_tile_hid = 0; + } + num_tiles_left -= n_tile - n_tile_min + 1; + + const BlockInfo binfo(params, cur_bidb); + // This is a hack, we really need to handle this outside the kernel + // But can't figure out a way to get actual seqlen_k in host-side code. + int max_actual_tiles = (binfo.actual_seqlen_k + kBlockN - 1) / kBlockN; + int num_actual_tiles_in_block = Is_causal ? std::max(max_actual_tiles, (int)ceilf(block_scale * (cur_m_block + 1))) : max_actual_tiles; + if (n_tile >= max_actual_tiles) { + tKgK.data() = tKgK.data() + (-int((n_tile - max_actual_tiles - 1) * kBlockN * params.k_row_stride)); + tVgV.data() = tVgV.data() + (-int((n_tile - max_actual_tiles - 1) * kBlockN * params.v_row_stride)); + n_tile = max_actual_tiles - 1; + } + if constexpr (Append_KV) { + if (end_cta) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, cur_bidb) + (n_tile * kBlockN) * params.knew_row_stride + (cur_bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, cur_bidb) + (n_tile * kBlockN) * params.vnew_row_stride + (cur_bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. 
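// Illustrative sketch (not part of this patch) of what the copy_w_min_idx loop below
// accomplishes for one batch entry and one KV head: new rows are appended right after the
// cached prefix, so cache row seqlen_k_cache + t receives row t of the "new" tensor. The
// buffers here are hypothetical contiguous row-major [rows][num_heads_k * head_size].
#include <cstddef>

static void append_kv_reference(float* kcache, const float* knew,
                                int seqlen_k_cache, int seqlen_k_new, int row_elems) {
  for (int t = 0; t < seqlen_k_new; ++t) {
    const float* src = knew + (size_t)t * row_elems;
    float* dst = kcache + (size_t)(seqlen_k_cache + t) * row_elems;
    for (int e = 0; e < row_elems; ++e) dst[e] = src[e];
  }
}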
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); + } +#endif + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_tile_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && (blockIdx.z == tracing_block || blockIdx.z == tracing_block + 1)) { + printf("Block %d n_tile_min %d n_tile %d n_block_copy_min %d\n", blockIdx.z, n_tile_min, n_tile, n_block_copy_min); + } +#endif + for (int n_block = n_tile; n_block >= n_block_copy_min; n_block--) { + lean::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + + lean::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. 
+ __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + } + lean::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - cur_m_block * kBlockM); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_tile * kBlockN); + cute::cp_async_fence(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("##### CTA : %d\n", blockIdx.z); + printf("cur_bidb = %d/%d\n", cur_bidb, params.b); + printf("cur_bidh = %d/%d\n", cur_bidh, params.h); + printf("cur_m_block = %d\n", cur_m_block); + printf("seqlen_k_cache = %d\n", binfo.seqlen_k_cache); + printf("actual_seqlen_q = %d\n", binfo.actual_seqlen_q); + printf("actual_seqlen_k = %d\n", binfo.actual_seqlen_k); + printf("num_tiles_in_block = %d\n", num_tiles_in_block); + printf("n_tile(new) = %d\n", n_tile); + printf("n_tile_min = %d\n", n_tile_min); + printf("host_cta = %d\n", host_cta); + printf("end_cta = %d\n", end_cta); + printf("n_split_idx = %d\n", n_split_idx); + printf("total_splits = %d\n", total_splits); + printf("\n#### For next block:\n"); + printf("start_tile_gid = %d\n", start_tile_gid); + printf("start_tile_hid = %d\n", start_tile_hid); + printf("num_tiles_left = %d\n", num_tiles_left); + printf("\n"); + } +#endif + + // All scheduling policy decisions should be made above this line + clear(acc_o); + + lean::Softmax<2 * size<1>(acc_o)> softmax; + + lean::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, 0.0f); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + lean::cp_async_wait<0>(); + __syncthreads(); + +#if defined(DEBUG_LEAN_ATTENTION) + prologue_duration += clock64() - prologue_start; + auto compute_start = clock64(); +#endif + + // Clear the smem tiles to account for predicated off loads + lean::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_tile * kBlockN); + cute::cp_async_fence(); + + lean::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - Svalue: acc_s[0] = %f\n", acc_s(0)); + } +#endif + + mask.template apply_mask( + acc_s, n_tile * kBlockN, cur_m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); + + lean::cp_async_wait<0>(); + __syncthreads(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + print(tVsV); + } + // __syncthreads(); +#endif + + if (n_tile > n_tile_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. 
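// Illustrative sketch (not part of this patch): the prefetch pattern used here — issue the
// copy of the next K tile while the current tile is consumed, and only wait on it right
// before it is needed — expressed with CUDA's pipeline primitives rather than cute's
// cp_async helpers. Names and tile sizes are hypothetical; smem_next must point into
// __shared__ storage.
#include <cuda_pipeline_primitives.h>

__device__ void prefetch_then_compute(float4* smem_next, const float4* gmem_next, int n_vec) {
  for (int i = threadIdx.x; i < n_vec; i += blockDim.x) {
    __pipeline_memcpy_async(&smem_next[i], &gmem_next[i], sizeof(float4));  // async gmem -> smem
  }
  __pipeline_commit();       // analogous to cute::cp_async_fence()
  // ... run the GEMM on the tile that is already resident in shared memory ...
  __pipeline_wait_prior(0);  // analogous to lean::cp_async_wait<0>()
  __syncthreads();
}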
+ cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - PValue[0] = %f\n", acc_s(0)); + } +#endif + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = lean::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), lean::convert_layout_acc_Aregs(rP.layout())); + + lean::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - AfterPV[0] = %f\n", acc_o(0)); + } +#endif + + n_tile -= 1; + +#if defined(DEBUG_LEAN_ATTENTION) + comp1_duration += clock64() - compute_start; + compute_start = clock64(); +#endif + + // These are the iterations where we don't need masking on S + for (; n_tile >= n_tile_min; --n_tile) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + lean::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + + lean::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + lean::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d Svalue: acc_s[0] = %f\n", n_tile, acc_s(0)); + } +#endif + + lean::cp_async_wait<0>(); + __syncthreads(); + if (n_tile > n_tile_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_tile * kBlockN, cur_m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d Pvalue: acc_s[0] = %f\n", n_tile, acc_s(0)); + } +#endif + Tensor rP = lean::convert_type(acc_s); + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), lean::convert_layout_acc_Aregs(rP.layout())); + + lean::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d AfterPV[0] = %f\n", n_tile, acc_o(0)); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + // Epilogue + comp2_duration += clock64() - compute_start; + auto epilogue_start = clock64(); +#endif + + if (host_cta && end_cta) { +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("acc_o[0] = %f\n", acc_o(0)); + } +#endif + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("lse[0] = %f\n", lse(0)); + printf("acc_o[0] = %f\n", acc_o(0)); + } +#endif + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = lean::convert_type(acc_o); + + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = cur_bidb * params.o_batch_stride + + cur_m_block * kBlockM * params.o_row_stride + cur_bidh * params.o_head_stride; + + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("tOpO[0] = %d\n", tOpO(0)); + printf("tOrO[0] = %f\n", tOrO(0)); + } +#endif + // Clear_OOB_K must be false since we don't want to write zeros to gmem + lean::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, params.seqlen_q - cur_m_block * kBlockM); + // epil1_duration += clock64() - epilogue_start; + } else if (!host_cta) { + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using 
SmemTiledCopyO = typename Kernel_traits::SmemCopyAtomOaccum; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = lean::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_oaccum = (((index_t)(n_split_idx * params.b + cur_bidb) * params.h + cur_bidh) * params.seqlen_q + cur_m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + cur_bidb) * params.h + cur_bidh) * params.seqlen_q + cur_m_block * kBlockM; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("n_split_idx = %d\n", n_split_idx); + // printf("row_offset_o = %" PRId64 "\n", row_offset_o); + printf("row_offset_oaccum = %" PRId64 "\n", row_offset_oaccum); + printf("row_offset_lseaccum = %" PRId64 "\n", row_offset_lseaccum); + } +#endif + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + (row_offset_oaccum)), + Shape, Int>{}, + make_stride(kHeadDim, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + // This partitioning is unequal because only threads 0,4,8,etc write to gLSE + // and the rest are unused. 
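// Illustrative sketch (not part of this patch): the indexing computed just above, written
// out per element. The partial-output buffer behaves as
//   oaccum[num_splits][batch][num_heads][seqlen_q][head_size_rounded]
// and the partial-LSE buffer as lse_accum[num_splits][batch][num_heads][seqlen_q].
static inline int64_t oaccum_offset(int split, int batch, int b, int head, int h,
                                    int row, int seqlen_q, int d_rounded) {
  return (((int64_t)(split * b + batch) * h + head) * seqlen_q + row) * (int64_t)d_rounded;
}
static inline int64_t lseaccum_offset(int split, int batch, int b, int head, int h,
                                      int row, int seqlen_q) {
  return ((int64_t)(split * b + batch) * h + head) * seqlen_q + row;
}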
+ if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.seqlen_q - cur_m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + lean::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - cur_m_block * kBlockM); + + __threadfence(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && (blockIdx.z == tracing_block || blockIdx.z == tracing_block + 1)) { + printf("Block %d Writing Flag %d\n", blockIdx.z, (cur_bidb * params.h * num_m_blocks_per_head) + (cur_bidh * num_m_blocks_per_head) + cur_m_block); + } +#endif + + atomicAdd(reinterpret_cast(params.sync_flag) + (cur_bidb * params.h * num_m_blocks_per_head) + (cur_bidh * num_m_blocks_per_head) + cur_m_block, 1); + +#if defined(DEBUG_LEAN_ATTENTION) + epil2_duration += clock64() - epilogue_start; +#endif + } else { + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + //////////////////////////////////////////////////////////////////////////////// +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Before LSE acc_o[0] = %f\n", acc_o(0)); + } +#endif + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("After LSE acc_o[0] = %f\n", acc_o(0)); + printf("lse[0] = %f\n", lse(0)); + } +#endif + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = typename Kernel_traits::SmemCopyAtomOaccum; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = lean::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + __syncthreads(); + + // We move to SMEM and back because we need equal distribution of + // accum registers. Initially only threads 0,4,8,etc have oaccum values. + // So, first move them to SMEM. 
+ cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_oaccum = ((cur_bidb * params.h + cur_bidh) * (index_t)params.seqlen_q + cur_m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = (cur_bidb * params.h + cur_bidh) * (index_t)params.seqlen_q + cur_m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + (row_offset_oaccum)), + Shape, Int>{}, + make_stride(kHeadDim, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Block %d row_offset_oaccum = %" PRId64 "\n", blockIdx.z, row_offset_oaccum); + printf("Block %d row_offset_lseaccum = %" PRId64 "\n", blockIdx.z, row_offset_lseaccum); + } +#endif + + // GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + // auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + // Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + // Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOgOaccumReg = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccumReg)); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("First split t0g0accum.data() %p\n", tOgOaccum.data()); + } +#endif + + __syncthreads(); + + // Bring the oaccum back from SMEM to registers + // Now all threads have oaccum values equaly distributed. + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + ///////////////////////////////////////////////////////////////////////////// + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + Tensor sLSE = make_tensor(sV.data(), Shape, Int>{}); // (SMEM_M,SMEM_N) + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + + // This partitioning is unequal because only threads 0,4,8,etc write to gLSE + // and the rest are unused. + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int col = get<0>(taccOcO_row(mi)); + if (col < params.seqlen_q - cur_m_block * kBlockM) { + sLSE(0, col) = lse(mi); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("threadIdx.x %d col %d mi%d slSE %f\n", threadIdx.x, col, mi, lse(mi)); + } +#endif + } + } + } + + // Synchronize here to make sure all atomics are visible to all threads. + // Not exactly sure why we need this, but it seems to be necessary. 
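// Illustrative sketch (not part of this patch) of the flag protocol used around here:
// every thread of a finished non-host CTA bumps the per-(batch, head, m-block) counter
// after a __threadfence(), so the host CTA can spin until the counter reaches
// (total_splits - 1) * kNThreads and then safely read the partial results from gmem.
__device__ void producer_signal(int* flag) {
  __threadfence();            // make the partial Oaccum/LSE stores visible first
  atomicAdd(flag, 1);         // executed by every thread of the producer CTA
}
__device__ void consumer_wait(int* flag, int expected) {
  while (atomicAdd(flag, 0) < expected) {  // atomicAdd(..., 0) acts as an atomic read
    __threadfence();
  }
}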
+ __threadfence(); + while (atomicAdd(reinterpret_cast(params.sync_flag) + + (cur_bidb * params.h * num_m_blocks_per_head) + + (cur_bidh * num_m_blocks_per_head) + cur_m_block, + 0) < (total_splits - 1) * kNThreads) { + __threadfence(); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("Waiting Block: %d target-value: %d\n", blockIdx.z, (total_splits - 1) * kNThreads); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + // Print sync flag value + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + int32_t sync_flag = atomicAdd(reinterpret_cast(params.sync_flag) + + (cur_bidb * params.h * num_m_blocks_per_head) + + (cur_bidh * num_m_blocks_per_head) + cur_m_block, + 0); + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("Sync flag value: %d\n", sync_flag); + } + } +#endif + + Tensor gLSEaccumRead = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // R + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; // R + +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + // We skip the first row = 0, as we already populated it in shared memory. + ElementAccum lse = (row > 0 && row < total_splits && col < params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum) ? gLSEaccumRead(row, col) : -std::numeric_limits::infinity(); + if (row > 0 && row < kMaxSplits) { + sLSE(row, col) = lse; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse %f\n", threadIdx.x, l, row, col, lse); + } +#endif + } + } + __syncthreads(); // For all LSEs to reach shared memory + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("kNLsePerThread %d kRowsPerLoadLSE %d kRowsPerLoadTranspose %d\n", kNLsePerThread, kRowsPerLoadLSE, kRowsPerLoadTranspose); + } +#endif + + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -std::numeric_limits::infinity(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse_accum %f\n", threadIdx.x, l, row, col, lse_accum(l)); + } +#endif + } + + // Compute the logsumexp of the LSE along the split dimension. 
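// Scalar reference (illustrative, not part of this patch) for the reduction performed
// below: each split s contributed a partial output o[s] (softmax over its own keys times V)
// and its log-sum-exp lse[s]; the merged output re-weights each partial by
// exp(lse[s] - logsumexp over all splits).
#include <cmath>
#include <vector>

static float merge_splits(const std::vector<float>& lse, const std::vector<float>& o) {
  float lse_max = lse[0];
  for (float v : lse) lse_max = v > lse_max ? v : lse_max;
  float lse_sum = 0.f;
  for (float v : lse) lse_sum += std::exp(v - lse_max);   // exponents are <= 0, no overflow
  const float lse_logsum = std::log(lse_sum) + lse_max;   // global log-sum-exp
  float out = 0.f;
  for (size_t s = 0; s < lse.size(); ++s) out += std::exp(lse[s] - lse_logsum) * o[s];
  return out;
}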
+ ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) + ? std::numeric_limits::infinity() + : logf(lse_sum) + lse_max; +// if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < total_splits && col < kBlockM) { + sLSE(row, col) = expf(lse_accum(l) - lse_logsum); + ElementAccum lse_scale = sLSE(row, col); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse_accum %f lse_logsum %f sLSE %f\n", threadIdx.x, l, row, col, lse_accum(l), lse_logsum, lse_scale); + } +#endif + } + } + + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } + + // Sync here for sLSE stores to go through + __syncthreads(); +// First reduce self Oaccum +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(0, row); +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d Split %d m %d Row %d k %d i %d LSE %f Oaccum %f O %f\n", threadIdx.x, 0, m, row, k, i, lse_scale, tOrOaccum(i, m, k), tOrO(i, m, k)); + } +#endif + } + } + } + + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * (index_t)params.seqlen_q * params.d_rounded; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("After First Split t0g0accum.data() %p\n", tOgOaccum.data()); + } +#endif + // Load Oaccum in then scale and accumulate to O + // Here m is each row of 0accum along token dimension + // k is + for (int split = 1; split < total_splits; ++split) { + lean::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(split, 
row); +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d Split %d m %d Row %d k %d i %d LSE %f Oaccum %f O %f\n", threadIdx.x, split, m, row, k, i, lse_scale, tOrOaccum(i, m, k), tOrO(i, m, k)); + } +#endif + } + } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * (index_t)params.seqlen_q * params.d_rounded; + } + + Tensor r1 = lean::convert_type(tOrO); + +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(r1); ++m) { + const int idx = cur_m_block * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.seqlen_q) { + // The index to the rows of Q + const int row = idx; + auto o_ptr = reinterpret_cast(params.o_ptr) + cur_bidb * params.o_batch_stride + cur_bidh * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(r1); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(r1))::value>>{}, Stride<_1>{}); + copy(r1(_, m, k), gO); + } + } + } + } +#if defined(DEBUG_LEAN_ATTENTION) + epil3_duration += clock64() - epilogue_start; +#endif + } + + if (num_tiles_left) { + // We can probably do better than this + // We first decrement the pointers back to starting. + // We can probably just use q_ptr and k_ptr directly. But can't figure out how to do it. + // Without disturbing the gQ, gK, gV tensor pointer CUTE objects. + tQgQ.data() = tQgQ.data() + (-int(row_offset_q)); + tKgK.data() = tKgK.data() + (((num_tiles_in_block - n_tile_min - 1) * kBlockN) * params.k_row_stride - row_offset_k); + tVgV.data() = tVgV.data() + (((num_tiles_in_block - n_tile_min - 1) * kBlockN) * params.v_row_stride - row_offset_v); + cur_m_block = cur_m_block + 1 >= num_m_blocks_per_head ? 0 : cur_m_block + 1; + num_tiles_in_block = Is_causal ? (int)ceilf(block_scale * (cur_m_block + 1)) : num_tiles_per_head; + n_tile = num_tiles_left - 1 >= num_tiles_in_block ? 
num_tiles_in_block - 1 : num_tiles_left - 1; + n_tile_min = 0; + cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + + row_offset_q = cur_bidb * params.q_batch_stride + + +cur_m_block * kBlockM * params.q_row_stride + cur_bidh * params.q_head_stride; + row_offset_k = cur_bidb * params.k_batch_stride + + +n_tile * kBlockN * params.k_row_stride + (cur_bidh / params.h_h_k_ratio) * params.k_head_stride; + row_offset_v = cur_bidb * params.v_batch_stride + + +n_tile * kBlockN * params.v_row_stride + (cur_bidh / params.h_h_k_ratio) * params.v_head_stride; + + tQgQ.data() = tQgQ.data() + row_offset_q; + tKgK.data() = tKgK.data() + row_offset_k; + tVgV.data() = tVgV.data() + row_offset_v; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("#### Ready for next block:\n"); + printf("next_block %d\n", cur_m_block); + printf("n_tile %d\n", n_tile); + printf("row_offset_q = %" PRId64 "\n", row_offset_q); + printf("row_offset_k = %" PRId64 "\n", row_offset_k); + printf("row_offset_v = %" PRId64 "\n", row_offset_v); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + epilogue_duration += clock64() - epilogue_start; +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0) { + uint smid; + asm("mov.u32 %0, %smid;" : "=r"(smid)); + printf("%d %d %d %d %lld %lld %lld %lld %lld %lld %lld %lld\n", + blockIdx.z, num_blocks, smid, cta_id, clock64() - kernel_start, prologue_duration, comp1_duration, + comp2_duration, epilogue_duration, epil1_duration, epil2_duration, epil3_duration); + } +#endif +} + +template +inline __device__ void lean_compute_attn(const Params& params) { + // const int cta_id = blockIdx.z < 54 ? 4*blockIdx.z : blockIdx.z < 108 ? 4*(blockIdx.z % 54) + 2 : blockIdx.z < 162 ? 4*(blockIdx.z % 108) + 1 : 4*(blockIdx.z % 162) + 3; + const int cta_id = blockIdx.z; + int start_tile_gid = cta_id < params.high_load_tbs ? params.max_tiles_per_tb * cta_id : (params.max_tiles_per_tb - 1) * cta_id + params.high_load_tbs; + int start_tile_hid = start_tile_gid % params.tiles_per_head; + int num_tiles = cta_id < params.high_load_tbs ? params.max_tiles_per_tb : params.max_tiles_per_tb - 1; + + lean::lean_compute_attn_impl_ver3(params, cta_id, start_tile_gid, start_tile_hid, num_tiles, params.tiles_per_head); +} + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h new file mode 100644 index 0000000000000..fcccb54ebf4e8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h @@ -0,0 +1,73 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
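// Illustrative sketch (not part of this patch): the tile range each CTA works on in
// lean_compute_attn above. The first high_load_tbs CTAs take max_tiles_per_tb tiles each,
// the remaining CTAs take one tile fewer, and the start offsets line up so the
// batch * heads * tiles_per_head tile space is covered exactly once.
struct TileRange { int start_gid; int count; };
static inline TileRange cta_tile_range(int cta_id, int high_load_tbs, int max_tiles_per_tb) {
  if (cta_id < high_load_tbs) {
    return {cta_id * max_tiles_per_tb, max_tiles_per_tb};
  }
  return {(max_tiles_per_tb - 1) * cta_id + high_load_tbs, max_tiles_per_tb - 1};
}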
+ ******************************************************************************/ + +#pragma once + +#include "contrib_ops/cuda/bert/lean_attention/static_switch.h" +#include "contrib_ops/cuda/bert/lean_attention/flash.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h" + +namespace onnxruntime { +namespace lean { + +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ + template \ + __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(lean_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, int kMaxSplits, bool Append_KV) { +#if defined(ARCH_SUPPORTS_FLASH) + lean::lean_compute_attn(params); +#else + FLASH_UNSUPPORTED_ARCH +#endif +} + +template +void run_lean_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + dim3 grid(1, 1, params.lean_griddimz); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + MAXSPLIT_SWITCH(params.num_splits, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV_Const, [&] { + auto kernel = &lean_fwd_kernel < Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, kMaxSplits, Append_KV_Const > ; + if (2 * smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 2 * smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + // This should be modified according to optimal lean tile size + constexpr static int kBlockM = Headdim <= 64 ? 64 : (Headdim <= 128 ? 64 : 64); + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_lean_fwd>(params, stream); +} + +} // namespace lean +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h new file mode 100644 index 0000000000000..2d33418d69667 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h @@ -0,0 +1,209 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once +#include +#include + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +template +__forceinline__ __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); + } + } + } + } +} + +template +__forceinline__ __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +__forceinline__ __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == 
size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -std::numeric_limits::infinity(); + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +struct Mask { + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope = 0.f) + : max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q), window_size_left(window_size_left), window_size_right(window_size_right), alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor& tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), lean::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? 
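// Scalar reference (illustrative, not part of this patch) for the window test applied by
// the masking code above, with the query aligned to the end of the key sequence. Positions
// failing this test are set to -infinity before the softmax. window_size_left < 0 means
// "unbounded to the left", so causal masking is the case (window_size_left = -1,
// window_size_right = 0).
#include <algorithm>

static inline bool attends(int row_idx, int col_idx, int max_seqlen_q, int max_seqlen_k,
                           int window_size_left, int window_size_right) {
  const int limit_right =
      std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
  const int limit_left =
      window_size_left < 0 ? 0
                           : std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
  return col_idx >= limit_left && col_idx < limit_right;
}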
+ static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); + } + } + } + } + } + } else { +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); + } + } + } + } + } + } + } + } + }; +}; + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h new file mode 100644 index 0000000000000..0b6ffb3f1985a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h @@ -0,0 +1,200 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
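// Scalar reference (illustrative, not part of this patch) of the running-softmax update
// that softmax_rescale_o below performs per row: when a new block of scores s[] arrives,
// the running max m, the running denominator l and the output accumulator o are rescaled
// by exp(m_old - m_new), so o / l always equals softmax(scores so far) applied to V.
// Assumes at least one finite score in the first block.
#include <cmath>

struct RunningSoftmaxRow {
  float m = -INFINITY, l = 0.f, o = 0.f;  // one query row, one output feature
  void update(const float* s, const float* v, int n) {
    float m_new = m;
    for (int i = 0; i < n; ++i) m_new = s[i] > m_new ? s[i] : m_new;
    const float rescale = std::exp(m - m_new);  // 0 on the first call (m == -inf)
    l *= rescale;
    o *= rescale;
    for (int i = 0; i < n; ++i) {
      const float p = std::exp(s[i] - m_new);
      l += p;
      o += p * v[i];
    }
    m = m_new;
  }
  float finalize() const { return o / l; }  // normalize_softmax_lse; it also records lse = m + log(l)
};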
+ ******************************************************************************/ + +#pragma once +#include +#include + +#include + +#include + +#include "contrib_ops/cuda/bert/lean_attention/utils.h" + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor& sum) { + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -std::numeric_limits::infinity() + ? 0.f + : max(mi) * (Scale_max ? scale : float(M_LOG2E)); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { +// Instead of computing exp(x - max), we compute exp2(x * log_2(e) - +// max * log_2(e)) This allows the compiler to use the ffma +// instruction instead of fadd and fmul separately. +// The following macro will disable the use of fma. +// See: https://github.com/pytorch/pytorch/issues/121558 for more details +// This macro is set in PyTorch and not FlashAttention +#ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); +#else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); +#endif + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? 
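Aside (not from the diff): scale_apply_exp2 above leans on the identity exp(scale * (x - max)) == exp2(x * scale_log2 - max * scale_log2), where scale_log2 already folds in log2(e), plus a guard that maps a fully masked row to 0 instead of NaN. A quick host-side check of both points:

#include <cmath>
#include <cstdio>
#include <limits>

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634
#endif

int main() {
  const float scale = 0.125f;                       // e.g. 1/sqrt(head_size)
  const float scale_log2 = scale * float(M_LOG2E);  // what softmax_scale_log2 carries
  const float x = 3.7f, row_max = 5.2f;

  // Straightforward exp versus the exp2 form, which lets the GPU fold the
  // multiply and subtract into a single FFMA.
  const float ref = std::exp(scale * (x - row_max));
  const float fast = std::exp2(x * scale_log2 - row_max * scale_log2);
  std::printf("exp: %.7f  exp2 trick: %.7f\n", ref, fast);

  // The -inf guard: for a fully masked row both the score and the row max are -inf.
  // Forcing max_scaled to 0 keeps the result at exp2(-inf) == 0 instead of the NaN
  // that (-inf) - (-inf) would produce.
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const float masked_score = neg_inf, masked_max = neg_inf;
  const float max_scaled = (masked_max == neg_inf) ? 0.f : masked_max * scale_log2;
  std::printf("masked probability: %.1f\n", std::exp2(masked_score * scale_log2 - max_scaled));
  return 0;
}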
tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -std::numeric_limits::infinity() ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0& acc_s, Tensor1& acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), lean::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + lean::template reduce_max(scores, row_max); + lean::scale_apply_exp2(scores, row_max, softmax_scale_log2); + lean::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + lean::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), lean::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -std::numeric_limits::infinity() ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + lean::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + lean::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, float rp_dropout = 1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), lean::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 
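For orientation (again outside the diff), softmax_rescale_o above is the streaming-softmax update: when a new score block raises the running row max, the accumulated output and row sum are rescaled by exp(old_max - new_max), and normalize_softmax_lse later divides by the final sum and reports max + log(sum). A scalar sketch of that bookkeeping, with my own variable names:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> scores = {0.3f, 2.1f, -1.0f, 0.7f, 1.5f, -0.2f};
  const std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  const size_t chunk = 2;

  // Streaming pass: running max m, running sum l, running weighted output acc.
  float m = -INFINITY, l = 0.f, acc = 0.f;
  for (size_t start = 0; start < scores.size(); start += chunk) {
    float m_new = m;
    for (size_t i = start; i < start + chunk; ++i) m_new = std::max(m_new, scores[i]);
    const float rescale = std::exp(m - m_new);  // the kernel uses exp2f with log2(e) folded in
    l *= rescale;
    acc *= rescale;
    for (size_t i = start; i < start + chunk; ++i) {
      const float p = std::exp(scores[i] - m_new);
      l += p;
      acc += p * values[i];
    }
    m = m_new;
  }
  const float out = acc / l;
  const float lse = m + std::log(l);  // log-sum-exp, as returned per row

  // One-shot reference.
  float ref_m = -INFINITY, ref_l = 0.f, ref_acc = 0.f;
  for (float s : scores) ref_m = std::max(ref_m, s);
  for (size_t i = 0; i < scores.size(); ++i) {
    const float p = std::exp(scores[i] - ref_m);
    ref_l += p;
    ref_acc += p * values[i];
  }
  std::printf("streaming: out=%.6f lse=%.6f  reference: out=%.6f lse=%.6f\n",
              out, lse, ref_acc / ref_l, ref_m + std::log(ref_l));
  return 0;
}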
1.f : 1.f / sum; + // if (threadIdx.x == 0 && blockIdx.z == 0) { + // printf("sum: %f, inv_sum: %f\n", sum, inv_sum); + // printf("mi %d row_max %f softmax_scale %f\n", mi, row_max(mi), softmax_scale); + // } + lse(mi) = (sum == 0.f || sum != sum) + ? (Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity()) + : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + return lse; + }; +}; + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h new file mode 100644 index 0000000000000..7873f67471d5d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h @@ -0,0 +1,109 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT +#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI +#define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K +#define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else +#define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL +#define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define LOCAL_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } \ + }() + +#define MAXSPLIT_SWITCH(MAXSPLITS, ...) 
\ + [&] { \ + if (MAXSPLITS <= 2) { \ + constexpr static int kMaxSplits = 2; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 4) { \ + constexpr static int kMaxSplits = 4; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 8) { \ + constexpr static int kMaxSplits = 8; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 16) { \ + constexpr static int kMaxSplits = 16; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 32) { \ + constexpr static int kMaxSplits = 32; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 64) { \ + constexpr static int kMaxSplits = 64; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h new file mode 100644 index 0000000000000..c76849686d539 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h @@ -0,0 +1,411 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace onnxruntime { +namespace lean { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ uint32_t relu2(const uint32_t x); + +template <> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); + +template <> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template <> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? 
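The usage comment at the top of static_switch.h is the whole idea: a runtime flag is lifted into a constexpr so the body can instantiate a template. A standalone illustration of that dispatch pattern (the macro is repeated only to make the snippet self-contained, and run_kernel is a hypothetical stand-in, not a function from this patch):

#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
  [&] {                                           \
    if (COND) {                                   \
      constexpr static bool CONST_NAME = true;    \
      return __VA_ARGS__();                       \
    } else {                                      \
      constexpr static bool CONST_NAME = false;   \
      return __VA_ARGS__();                       \
    }                                             \
  }()

// A kernel-like template: the bool must be known at compile time.
template <bool Is_causal>
int run_kernel(int n) {
  return Is_causal ? n * 2 : n;
}

int main(int argc, char**) {
  const bool causal = (argc > 1);  // runtime value
  // BOOL_SWITCH turns the runtime bool into the constexpr IsCausalConst.
  const int result = BOOL_SWITCH(causal, IsCausalConst, [&] {
    return run_kernel<IsCausalConst>(21);
  });
  std::printf("result = %d\n", result);
  return 0;
}

The HEADDIM_SWITCH and MAXSPLIT_SWITCH macros above apply the same trick to a small set of integer buckets instead of a bool.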
x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template <> +struct Allreduce<1> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm_rs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, 
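As a side note, the Allreduce template above is the classic XOR-shuffle butterfly; the softmax code uses Allreduce<4> (offsets 2 and 1) because each accumulator row is shared by four lanes of the MMA tile. A minimal CUDA sketch of the same pattern across a full warp, not taken from this patch:

#include <cstdio>
#include <cuda_runtime.h>

// Every lane ends up holding the warp-wide sum.
__global__ void warp_allreduce_demo(float* out) {
  float x = static_cast<float>(threadIdx.x);  // lane id as the value
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
    x += __shfl_xor_sync(0xffffffffu, x, offset);
  }
  out[threadIdx.x] = x;  // 0 + 1 + ... + 31 = 496 on every lane
}

int main() {
  float* d_out = nullptr;
  cudaMalloc(&d_out, 32 * sizeof(float));
  warp_allreduce_demo<<<1, 32>>>(d_out);
  float h_out[32];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  std::printf("lane 0: %.0f, lane 31: %.0f\n", h_out[0], h_out[31]);
  cudaFree(d_out);
  return 0;
}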
tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void relu_(Tensor& tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); +#pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +__forceinline__ __device__ auto convert_type_relu(Tensor const& tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = 
decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); +#pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = lean::convert_type(tensor); + lean::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, const int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. 
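For reference, the copy() helper above is a guarded tile copy: rows past max_MN are skipped or cleared, and columns that fail the per-k predicate are cleared when Clear_OOB_K is set (the static_assert rules out clearing MN without clearing K). A plain-C++ sketch of that boundary handling, with boolean extents standing in for the predicates and all names my own:

#include <cstdio>
#include <vector>

// Copy src -> dst tile by tile, clearing out-of-bounds rows/columns
// instead of reading past the valid data.
void guarded_copy(const std::vector<float>& src, std::vector<float>& dst,
                  int rows, int cols,               // tile shape
                  int valid_rows, int valid_cols,   // actual data extent
                  bool clear_oob_mn, bool clear_oob_k) {
  for (int m = 0; m < rows; ++m) {
    if (m < valid_rows) {
      for (int k = 0; k < cols; ++k) {
        if (k < valid_cols) {
          dst[m * cols + k] = src[m * cols + k];
        } else if (clear_oob_k) {
          dst[m * cols + k] = 0.f;
        }
      }
    } else if (clear_oob_mn) {
      for (int k = 0; k < cols; ++k) dst[m * cols + k] = 0.f;
    }
  }
}

int main() {
  const int rows = 4, cols = 4;
  std::vector<float> src(rows * cols, 1.f), dst(rows * cols, -1.f);
  guarded_copy(src, dst, rows, cols, /*valid_rows=*/3, /*valid_cols=*/2,
               /*clear_oob_mn=*/true, /*clear_oob_k=*/true);
  for (int m = 0; m < rows; ++m) {
    for (int k = 0; k < cols; ++k) std::printf("%4.0f", dst[m * cols + k]);
    std::printf("\n");
  }
  return 0;
}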
+ // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index c9c66b73b3e9d..a8f94304f8141 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -396,7 +396,7 @@ Status LaunchLongformerSoftmaxKernel( cudaDataType_t Atype; cudaDataType_t Btype; cudaDataType_t Ctype; - cudaDataType_t resultType; + cublasComputeType_t resultType; cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; __half one_fp16, zero_fp16; @@ -412,7 +412,7 @@ Status LaunchLongformerSoftmaxKernel( Atype = CUDA_R_16F; Btype = CUDA_R_16F; Ctype = CUDA_R_16F; - resultType = CUDA_R_16F; + resultType = CUBLAS_COMPUTE_16F; algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; } else { one_fp32 = 1.f; @@ -423,7 +423,7 @@ Status LaunchLongformerSoftmaxKernel( Atype = CUDA_R_32F; Btype = CUDA_R_32F; Ctype = CUDA_R_32F; - resultType = CUDA_R_32F; + resultType = CUBLAS_COMPUTE_32F; } // Strided batch matrix multiply diff --git 
a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu index 2c251246267b7..9f92faac25b73 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu @@ -221,7 +221,7 @@ Status LaunchLongformerSoftmaxSimpleKernel( cudaDataType_t Atype; cudaDataType_t Btype; cudaDataType_t Ctype; - cudaDataType_t resultType; + cublasComputeType_t resultType; cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; __half one_fp16, zero_fp16; @@ -237,7 +237,7 @@ Status LaunchLongformerSoftmaxSimpleKernel( Atype = CUDA_R_16F; Btype = CUDA_R_16F; Ctype = CUDA_R_16F; - resultType = CUDA_R_16F; + resultType = CUBLAS_COMPUTE_16F; algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; } else { one_fp32 = 1.f; @@ -248,7 +248,7 @@ Status LaunchLongformerSoftmaxSimpleKernel( Atype = CUDA_R_32F; Btype = CUDA_R_32F; Ctype = CUDA_R_32F; - resultType = CUDA_R_32F; + resultType = CUBLAS_COMPUTE_32F; } // Strided batch matrix multiply diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 9c558900d1fdb..e2587d172af94 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -54,6 +55,10 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); +#if USE_LEAN_ATTENTION + enable_lean_attention_ = sizeof(T) == 2 && kernel_options_->UseLeanAttention(); +#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); @@ -151,8 +156,64 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default; + typedef typename ToCudaType::MappedType CudaT; + AttentionData data; + +#if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION + size_t softmax_lse_bytes = 0; + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; +#endif + +#if USE_LEAN_ATTENTION + // Lean attention only supports token-generation phase with sequence_length == 1. 
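The resultType edits in the two Longformer files above track the cuBLAS 11 API: cublasGemmStridedBatchedEx still takes cudaDataType_t for A, B, and C, but its compute-type argument is a cublasComputeType_t. A minimal sketch of pairing those enums, outside the ORT wrappers and using only types I know exist in cublas_v2.h:

#include <cstdio>
#include <cublas_v2.h>

// Per-precision type selection as the Longformer kernels do it.
struct GemmTypeConfig {
  cudaDataType_t abc_type;           // Atype / Btype / Ctype
  cublasComputeType_t compute_type;  // the argument that changed in this patch
  cublasGemmAlgo_t algo;
};

inline GemmTypeConfig SelectGemmTypeConfig(bool use_half) {
  if (use_half) {
    return {CUDA_R_16F, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP};
  }
  return {CUDA_R_32F, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT};
}

int main() {
  const GemmTypeConfig cfg = SelectGemmTypeConfig(/*use_half=*/true);
  std::printf("data type=%d compute type=%d algo=%d\n",
              static_cast<int>(cfg.abc_type), static_cast<int>(cfg.compute_type),
              static_cast<int>(cfg.algo));
  return 0;
}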
+ bool use_lean_attention = enable_lean_attention_ && + parameters.sequence_length == 1 && + parameters.past_sequence_length > 0 && + nullptr == attention_bias && + nullptr == key_padding_mask && + parameters.head_size == parameters.v_head_size && + onnxruntime::lean::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + + size_t sync_flag_bytes = 0; + if (use_lean_attention) { + softmax_lse_bytes = onnxruntime::lean::get_softmax_lse_size(parameters.sequence_length, + parameters.batch_size, + parameters.num_heads); + + auto [num_splits, slse_accum_bytes, o_accum_bytes, sflag_bytes, griddimz, max_tiles_tb, hload_tbs, tiles_per_head] = onnxruntime::lean::get_num_splits_and_buffer_sizes( + parameters.batch_size, + parameters.sequence_length, + parameters.total_sequence_length, + parameters.num_heads, // q heads + parameters.num_heads, // kv heads + parameters.head_size, + device_prop.multiProcessorCount, + parameters.is_unidirectional); + + data.num_splits = static_cast(num_splits); + data.grid_dim_z = static_cast(griddimz); + data.max_tiles_per_tb = static_cast(max_tiles_tb); + data.high_load_tbs = static_cast(hload_tbs); + data.tiles_per_head = static_cast(tiles_per_head); + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + sync_flag_bytes = sflag_bytes; + kernel_type = AttentionKernelType::AttentionKernel_LeanAttention; + } + + auto lean_sync_flag_buffer = GetScratchBuffer(sync_flag_bytes, context->GetComputeStream()); + data.lean_sync_flag = reinterpret_cast(lean_sync_flag_buffer.get()); +#else + constexpr bool use_lean_attention = false; +#endif + #if USE_FLASH_ATTENTION - bool use_flash_attention = !disable_flash_attention_ && + bool use_flash_attention = kernel_type == AttentionKernelType::AttentionKernel_Default && + !disable_flash_attention_ && nullptr == attention_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && @@ -165,25 +226,35 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } + // Allocate buffers - size_t softmax_lse_accum_bytes = 0; - size_t out_accum_bytes = 0; if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, + parameters.batch_size, + parameters.num_heads); + using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = static_cast(num_splits); + data.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; kernel_type = AttentionKernelType::AttentionKernel_FlashAttention; } - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; - auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + +#if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, 
context->GetComputeStream()); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + if (use_flash_attention || use_lean_attention) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } #endif bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE || @@ -284,8 +355,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { kernel_type = AttentionKernelType::AttentionKernel_Unfused; } - typedef typename ToCudaType::MappedType CudaT; - AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); @@ -303,6 +372,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; + data.use_lean_attention = use_lean_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.kernel_type = kernel_type; data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault); @@ -331,6 +401,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_sdpa, @@ -342,16 +413,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.workspace_bytes = workspace_bytes; data.allow_debug_info = kernel_options_->AllowDebugInfo(); - if (softmax_lse_accum_buffer != nullptr) { - data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); - } - if (out_accum_buffer != nullptr) { - data.out_accum = reinterpret_cast(out_accum_buffer.get()); - } if (data.allow_debug_info) { AttentionKernelDebugInfo debug_info; debug_info.use_flash_attention = use_flash_attention; + debug_info.use_lean_attention = use_lean_attention; debug_info.use_cudnn_flash_attention = use_cudnn_sdpa; debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; debug_info.use_efficient_attention = use_memory_efficient_attention; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 8edc1d0e6ac06..b093b226c50b0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -32,6 +32,9 @@ class MultiHeadAttention final : public CudaKernel { bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; bool disable_flash_attention_; +#if USE_LEAN_ATTENTION + bool enable_lean_attention_; +#endif bool disable_memory_efficient_attention_; bool enable_cudnn_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu index 8a04ede231a27..ab809d12a89ad 100644 --- a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu @@ -6,7 +6,7 @@ Licensed under the MIT License. 
/* Kernel implementation for blocking repeated n-grams. */ - +#include #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/ngram_repeat_block_impl.h" @@ -48,7 +48,7 @@ __global__ void banRepeatedTokens(const int64_t* __restrict__ tokens, } if (is_banned == true) { auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1]; - lprobs[lprob_start + token_to_be_banned] = -INFINITY; + lprobs[lprob_start + token_to_be_banned] = -std::numeric_limits::infinity(); } } diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index 279df73ee3d45..0554cc34933f1 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -348,7 +348,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { } Status ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(s_.mutex); + std::lock_guard lock(s_.mutex); auto cudnnHandle = this->GetCudnnHandle(context); ORT_RETURN_IF_ERROR(UpdateState(context, true)); if (s_.Y->Shape().Size() == 0) { diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 1b774b163888f..33cd906508bcf 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -179,6 +179,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_fused_cross_attention = false; constexpr bool use_memory_efficient_attention = false; constexpr bool use_flash_attention = false; + constexpr bool use_lean_attention = false; constexpr bool use_cudnn_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, @@ -190,6 +191,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_flash_attention, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu index 5ac10f6321e63..44be2ef2375ee 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu @@ -60,7 +60,7 @@ struct TopK { __device__ __forceinline__ void Init() { for (int i = 0; i < max_k; i++) { key[i] = -1; - value[i] = NumericLimits::Min(); + value[i] = NumericLimits::Lowest(); } } }; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e047bd948434d..4e65336665bf7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds( CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, cuda_stream)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - CUDA_RETURN_IF_ERROR( - cudaMemcpyAsync(input_ids_data + static_cast(i) * current_length, - sequence_data, - current_length * sizeof(int32_t), - cudaMemcpyHostToDevice, - cuda_stream)); - } + // We expect sequences to point 
directly to device memory + int max_length = sequences.GetMaxLength(); + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + CUDA_RETURN_IF_ERROR( + cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t), + sequences_buffer.data(), max_length * sizeof(int32_t), + current_length * sizeof(int32_t), batch_beam_size, + cudaMemcpyDeviceToDevice, cuda_stream)); } next_inputs[0] = input_ids; diff --git a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu index 68a2e16482af9..b2969194ff400 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu @@ -5,6 +5,7 @@ #include + #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/cu_inc/common.cuh" @@ -19,7 +20,10 @@ struct TopOne { int32_t key; T value; - __device__ __host__ __forceinline__ TopOne(int32_t key = -1, T value = NumericLimits::Min()) : key(key), value(value) { + __device__ __host__ __forceinline__ TopOne() : key(-1), value(NumericLimits::Lowest()) { + } + + __device__ __host__ __forceinline__ TopOne(int32_t key, T value) : key(key), value(value) { } __device__ __forceinline__ void Reduce(int32_t k, T v) { diff --git a/onnxruntime/contrib_ops/js/bert/group_query_attention.h b/onnxruntime/contrib_ops/js/bert/group_query_attention.h index 7553883a2478d..dff8663133c31 100644 --- a/onnxruntime/contrib_ops/js/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/js/bert/group_query_attention.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once - +#include "contrib_ops/cpu/bert/gqa_attention_base.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { @@ -11,31 +11,29 @@ namespace js { using onnxruntime::js::JsKernel; -class GroupQueryAttention : public JsKernel { +class GroupQueryAttention : public JsKernel, GQAAttentionBase { public: explicit GroupQueryAttention(const OpKernelInfo& info) - : JsKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - scale_ = info.GetAttrOrDefault("scale", 0.0f); + : JsKernel(info), GQAAttentionBase(info, false) { JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({ "numHeads" : $1, "kvNumHeads" : $2, "scale" : $3, + "softcap" : $4, + "doRotary" : $5, + "rotaryInterleaved" : $6, + "smoothSoftmax" : $7, + "localWindowSize" : $8 }), static_cast(num_heads_), static_cast(kv_num_heads_), - static_cast(scale_)); + static_cast(scale_), + static_cast(softcap_), + static_cast(do_rotary_), + static_cast(rotary_interleaved_), + static_cast(use_smooth_softmax_), + static_cast(local_window_size_)); } - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // number of k and v heads - float scale_; // custom scale will be used if specified. 
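Regarding the UpdateDecoderFeeds change above: the per-beam loop of host-to-device copies becomes one strided device-to-device cudaMemcpy2DAsync, where each "row" is a beam's padded sequence buffer of pitch max_length and only the first current_length tokens are copied. A compact standalone sketch of the same call shape (buffers here are hypothetical):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const int batch_beam_size = 3, max_length = 8, current_length = 5;

  int32_t *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, batch_beam_size * max_length * sizeof(int32_t));
  cudaMalloc(&dst, batch_beam_size * current_length * sizeof(int32_t));
  cudaMemset(src, 0, batch_beam_size * max_length * sizeof(int32_t));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // width = current_length tokens per row (in bytes), height = batch_beam_size rows,
  // source pitch = max_length, destination pitch = current_length.
  cudaMemcpy2DAsync(dst, current_length * sizeof(int32_t),
                    src, max_length * sizeof(int32_t),
                    current_length * sizeof(int32_t), batch_beam_size,
                    cudaMemcpyDeviceToDevice, stream);

  cudaStreamSynchronize(stream);
  std::printf("copied %d x %d tokens\n", batch_beam_size, current_length);

  cudaStreamDestroy(stream);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}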
Default value is 1/sqrt(head_size) }; } // namespace js diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu index 473ab8dd3ce4d..b40fc2bf0eef8 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -84,7 +84,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { Tensor* present = context->Output(kPresentOutputIndex, present_shape); auto stream = Stream(context); - rocblas_handle rocblas = GetRocblasHandle(context); + hipblasHandle_t hipblas = GetHipblasHandle(context); using HipT = typename ToHipType::MappedType; using QkvProjectGeneric = GemmPermuteGenericPipeline; @@ -113,7 +113,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto& params = gemm_permute_params; params.tuning_ctx = GetTuningContext(); params.stream = context->GetComputeStream(); - params.handle = rocblas; + params.handle = hipblas; params.attention = &attn; params.device_prop = &device_prop; @@ -179,7 +179,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto& params = gemm_softmax_gemm_permute_params; params.tuning_ctx = GetTuningContext(); params.stream = context->GetComputeStream(); - params.handle = rocblas; + params.handle = hipblas; params.attention = &attn; params.device_prop = &device_prop; // FIXME: the params.scale seems to be different from AttentionParameters::scale; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index b94971ffd44d5..270a8e51daf88 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -182,7 +182,7 @@ Status DecoderQkvToContext( const hipDeviceProp_t& prop, RocmTuningContext* tuning_ctx, Stream* ort_stream, - rocblas_handle& rocblas, + hipblasHandle_t& hipblas, const size_t element_size, const int batch_size, const int sequence_length, @@ -284,7 +284,7 @@ Status DecoderQkvToContext( const int strideB = sequence_length * head_size; if (use_past && static_kv) { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, rocblas, + tuning_ctx, ort_stream, hipblas, blas::BlasOp::Trans, blas::BlasOp::NonTrans, kv_sequence_length, sequence_length, head_size, /*alpha=*/rsqrt_head_size, @@ -295,7 +295,7 @@ Status DecoderQkvToContext( BN)); } else { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, rocblas, + tuning_ctx, ort_stream, hipblas, blas::BlasOp::Trans, blas::BlasOp::NonTrans, kv_sequence_length, sequence_length, head_size, /*alpha=*/rsqrt_head_size, @@ -320,7 +320,7 @@ Status DecoderQkvToContext( // compute P*V (as V*P), and store in scratch3: BxNxSxH if (use_past && static_kv) { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, rocblas, + tuning_ctx, ort_stream, hipblas, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, head_size, sequence_length, kv_sequence_length, /*alpha=*/1.0f, @@ -331,7 +331,7 @@ Status DecoderQkvToContext( BN)); } else { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, rocblas, + tuning_ctx, ort_stream, hipblas, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, head_size, sequence_length, kv_sequence_length, /*alpha=*/1.0f, @@ -351,7 +351,7 @@ Status LaunchDecoderAttentionKernel( const hipDeviceProp_t& prop, RocmTuningContext* tuning_ctx, Stream* stream, - rocblas_handle& rocblas, + hipblasHandle_t& hipblas, const 
size_t element_size, const int batch_size, const int sequence_length, @@ -378,7 +378,7 @@ Status LaunchDecoderAttentionKernel( prop, tuning_ctx, stream, - rocblas, + hipblas, element_size, batch_size, sequence_length, @@ -405,7 +405,7 @@ Status LaunchDecoderAttentionKernel( prop, tuning_ctx, stream, - rocblas, + hipblas, element_size, batch_size, sequence_length, diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index d593bc0012826..6c2e36b596d32 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -4,7 +4,7 @@ #pragma once #include -#include +#include #include "contrib_ops/cpu/bert/attention_common.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/tunable/rocm_tunable.h" @@ -70,64 +70,59 @@ Status LaunchConcatTensorToTensor(hipStream_t stream, const half* tensor_add, half* tensor_out); -inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, - int n, - int k, - const void* alpha, - const void* A, - rocblas_datatype a_type, - rocblas_int lda, - rocblas_stride stride_A, - const void* b, - rocblas_datatype b_type, - rocblas_int ldb, - rocblas_stride stride_b, - const void* beta, - void* c, - rocblas_datatype c_type, - rocblas_int ldc, - rocblas_stride stride_c, - rocblas_int batch_count, - rocblas_datatype compute_type, - rocblas_gemm_algo algo) { - return rocblas_gemm_strided_batched_ex(handle, - transa, - transb, - m, // m - n, // n - k, // k - alpha, // alpha - A, // A - a_type, // A type - lda, // lda - stride_A, // strideA - b, // B - b_type, // B type - ldb, // ldb - stride_b, // strideB - beta, // beta - c, // C - c_type, // C type - ldc, // ldc - stride_c, // strideC - c, // D = C - c_type, // D type = C type - ldc, // ldd = ldc - stride_c, // strideD = strideC - batch_count, // batch count - compute_type, - algo, - 0, 0); +inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const void* alpha, + const void* A, + hipDataType a_type, + int lda, + hipblasStride stride_A, + const void* b, + hipDataType b_type, + int ldb, + hipblasStride stride_b, + const void* beta, + void* c, + hipDataType c_type, + int ldc, + hipblasStride stride_c, + int batch_count, + hipblasComputeType_t compute_type, + hipblasGemmAlgo_t algo) { + return hipblasGemmStridedBatchedEx(handle, + transa, + transb, + m, // m + n, // n + k, // k + alpha, // alpha + A, // A + a_type, // A type + lda, // lda + stride_A, // strideA + b, // B + b_type, // B type + ldb, // ldb + stride_b, // strideB + beta, // beta + c, // C + c_type, // C type + ldc, // ldc + stride_c, // strideC + batch_count, // batch count + compute_type, + algo); } // Compatible for CublasMathModeSetter -class CompatRocblasMathModeSetter { +class CompatHipblasMathModeSetter { public: - CompatRocblasMathModeSetter(const hipDeviceProp_t&, - rocblas_handle, + CompatHipblasMathModeSetter(const hipDeviceProp_t&, + hipblasHandle_t, int) { } }; diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh index 5401c850bc8f7..f7709e8242147 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh +++ 
b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh @@ -32,7 +32,7 @@ struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { return MakeString("M", m, "_N", n, "_K", k, "_B", batch); } - rocblas_handle handle; + hipblasHandle_t handle; const AttentionParameters* attention; const hipDeviceProp_t* device_prop; diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index e013f35e150c4..e190a6938dc6b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -388,7 +388,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { return {m, n, k, o, batch}; } - rocblas_handle handle; + hipblasHandle_t handle; const RocmAttentionParameters* attention; const hipDeviceProp_t* device_prop; diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h index d71c6d8440a44..0aff519d20e99 100644 --- a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -4,7 +4,7 @@ #pragma once #include -#include +#include #include "contrib_ops/cpu/bert/attention_common.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/tunable/rocm_tunable.h" @@ -17,7 +17,7 @@ Status LaunchDecoderAttentionKernel( const hipDeviceProp_t& prop, // Device Properties RocmTuningContext* tuning_ctx, // context for tuning Stream* stream, // ORT Stream - rocblas_handle& rocblas, // Rocblas handle + hipblasHandle_t& hipblas, // hipblas handle const size_t element_size, // Element size of input tensor const int batch_size, // Batch size (B) const int sequence_length, // Sequence length (S) diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc index 1121e82a99d3f..fdb62d3a2aec5 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc @@ -58,7 +58,7 @@ Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { using onnxruntime::rocm::tunable::blas::BlasOp; return blas::row_major::GemmFastGelu( - GetTuningContext(), ctx->GetComputeStream(), GetRocblasHandle(ctx), + GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), transa ? BlasOp::Trans : BlasOp::NonTrans, transb ? BlasOp::Trans : BlasOp::NonTrans, helper.M(), helper.N(), helper.K(), diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h index dd98b76153cc2..2b8a21b83f177 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h @@ -24,7 +24,7 @@ struct GemmFastGeluParams : OpParams { bool has_bias = (nullptr != bias) ? 
0 : 1; return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); } - rocblas_handle handle; + hipblasHandle_t handle; BlasOp opa; BlasOp opb; int64_t m; diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h index c4b4e68ed6275..2d6a47269ac48 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h @@ -14,7 +14,7 @@ namespace blas { #define GEMMFASTGELU(T, ScalarT) \ common::Status GemmFastGelu( \ - RocmTuningContext* tuning_ctx, Stream* stream, rocblas_handle handle, \ + RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ BlasOp opa, BlasOp opb, \ std::int64_t m, std::int64_t n, std::int64_t k, \ ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index b07f9214e340e..fe0d621f1d601 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -245,7 +245,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { GemmSoftmaxGemmPermuteParams params; params.tuning_ctx = GetTuningContext(); params.stream = context->GetComputeStream(); - params.handle = GetRocblasHandle(context); + params.handle = GetHipblasHandle(context); params.attention = &attn; params.device_prop = &device_prop; params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc index 63804f79a32fb..4f3be98d97f80 100644 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ b/onnxruntime/contrib_ops/rocm/fused_conv.cc @@ -144,7 +144,7 @@ class FusedConv : public onnxruntime::rocm::Conv { } Status ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(Base::s_.mutex); + std::lock_guard lock(Base::s_.mutex); ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); if (Base::s_.Y->Shape().Size() == 0) { @@ -342,7 +342,7 @@ class FusedConv : public onnxruntime::rocm::Conv { }; struct FusionPlanCache { - mutable OrtMutex mutex; + mutable std::mutex mutex; using HashKey = uint32_t; std::unordered_map cache_directory_; @@ -351,7 +351,7 @@ class FusedConv : public onnxruntime::rocm::Conv { FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key, std::function factory) { - std::lock_guard lock(mutex); + std::lock_guard lock(mutex); auto iter = cache_directory_.find(key); if (iter == cache_directory_.end()) { cache_directory_[key].fusion = std::make_unique(); diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu index b65841b359647..519c44f351e64 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -113,7 +113,7 @@ Status GemmFloat8::ComputeFp8Fp16Fp16( onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; params.tuning_ctx = GetTuningContext(); params.stream = ctx->GetComputeStream(); - params.handle = GetRocblasHandle(ctx); + params.handle = GetHipblasHandle(ctx); params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; params.opb = transB_ ? 
tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; @@ -157,7 +157,7 @@ Status GemmFloat8::ComputeFp16Fp8Fp16( onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; params.tuning_ctx = GetTuningContext(); params.stream = ctx->GetComputeStream(); - params.handle = GetRocblasHandle(ctx); + params.handle = GetHipblasHandle(ctx); params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh index 571936fc5f038..5cebb7576abf3 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh @@ -133,7 +133,7 @@ struct GemmFloat8Params : tunable::OpParams { return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); } - rocblas_handle handle; + hipblasHandle_t handle; BlasOp opa; BlasOp opb; int64_t m; diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc new file mode 100644 index 0000000000000..86dc959cf2e83 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -0,0 +1,459 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/bert/attention.h" + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("qkv_input", ShaderUsage::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" + << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" + << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; + if (has_bias_) { + shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; + } + shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; + if (has_bias_) { + shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; + } else { + shader.MainFunctionBody() << ";\n"; + } + + return Status::OK(); +} + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + ORT_ENFORCE(input_tensor->Shape().GetDims().size() == 3); + ORT_ENFORCE(output_tensor->Shape().GetDims().size() == 4); + + uint32_t data_size = SafeInt(output_tensor->Shape().Size()); + const int batch_offset = num_heads * sequence_length * head_size; + const int sequence_offset = num_heads * 
head_size; + const int head_offset = head_size; + bool has_bias = bias != nullptr; + + TransferBSDToBNSHProgram program{has_bias}; + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)}}); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + + return context.RunProgram(program); +}; + +void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) { + if (seqlen_k != nullptr) { + ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; + ss << "var past_sequence_length: u32 = " << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n"; + } else { + ss << "let past_sequence_length = uniforms.past_sequence_length;\n"; + } +} + +Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_key_) { + shader.AddInput("past_key", ShaderUsage::UseUniform); + } + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + if (seqlen_k_ != nullptr) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (has_present_key_) { + shader.AddOutput("present_key", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" + << "let m = workgroup_id.y * TILE_SIZE;\n" + << "let n = workgroup_id.x * TILE_SIZE;\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let sequence_length = uniforms.M;\n" + << "var total_sequence_length = uniforms.N;\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str(); + shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; + } + + shader.MainFunctionBody() << "var value = f32_val_t(0);\n" + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { + shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" + << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? 
"present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" + " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " }\n"; + } + + if (has_present_key_) { + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" + << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" + << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + + shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " + attention_bias[outputIdx]"; + } + shader.MainFunctionBody() << ";\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length, + const Tensor* seqlen_k) { + const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) + : parameters.scale_; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + constexpr int tile_size = 12; + const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .SetWorkgroupSize(tile_size, tile_size) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + if (seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AdditionalImplementation() << "var thread_max: array;\n" + << "var thread_sum: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let sequence_length = uniforms.sequence_length;\n" + << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str() + << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? 
"max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = f32_val_t(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" + << " }\n" + << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " var f32input = f32_val_t(x[offset + i]);\n" + << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << "}\n"; + if (seqlen_k_) { + shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" + << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" + << "}\n"; + } + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, + const Tensor* seqlen_k, bool is_first_prompt) { + const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 
2 : 1)); + int work_group_size = 64; + const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; + if (total_sequence_length_comp < work_group_size) { + work_group_size = 32; + } + const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; + + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k}; + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .CacheHint(work_group_size, is_first_prompt) + .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) + .SetWorkgroupSize(work_group_size) + .AddUniformVariables({{static_cast(batch_size)}, + {static_cast(num_heads)}, + {static_cast(past_sequence_length)}, + {static_cast(sequence_length)}, + {static_cast(total_sequence_length_comp)}, + {static_cast(elementsPerThread)}}); + + return context.RunProgram(program); +} + +Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_value_) { + shader.AddInput("past_value", ShaderUsage::UseUniform); + } + if (seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_present_value_) { + shader.AddOutput("present_value", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n"; + shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n" + << "let sequence_length = uniforms.M;\n" + << "var total_sequence_length = uniforms.K;\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str(); + shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; + } + + shader.MainFunctionBody() << "var value = probs_element_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" + << " }\n" + << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { + shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" + << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? 
"present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n" + << " }\n"; + } + + if (has_present_value_) { + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {\n" + << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + " + << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n" + << " output[outputIdx] = value;\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + WebgpuAttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length, + const Tensor* seqlen_k) { + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; + const bool has_present_value = output_count > 1 && past_value != nullptr; + constexpr int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size_)}, + 
{static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(parameters.v_hidden_size_ * parameters.n_reps)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; + + const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, total_sequence_length}); + const TensorShape probs_shape(probs_dims); + Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); + ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, + parameters, past_sequence_length, total_sequence_length, seqlen_k)); + + ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); + + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + parameters, past_sequence_length, total_sequence_length, seqlen_k)); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h new file mode 100644 index 0000000000000..03279fffbc3ef --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
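+// This header declares the WebGPU shader programs that back the attention
+// contrib ops implemented in attention.cc: TransferBSDToBNSHProgram (layout
+// transfer with an optional bias add), AttentionProbsProgram (tiled Q*K^T with
+// past/present KV handling and optional attention bias), InPlaceSoftmaxProgram
+// (row-wise softmax over the probability matrix), and VxAttentionScoreProgram
+// (probs * V with the BNSH -> BSND transpose on output).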
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class TransferBSDToBNSHProgram final : public Program { + public: + TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, + {"batch_offset", ProgramUniformVariableDataType::Uint32}, + {"sequence_offset", ProgramUniformVariableDataType::Uint32}, + {"head_offset", ProgramUniformVariableDataType::Uint32}, + {"bias_offset", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_; +}; + +class AttentionProbsProgram final : public Program { + public: + AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"n_reps", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_key_; + bool has_present_key_; + bool has_attention_bias_; + int tile_size_; + int components_; + int n_reps_; + const Tensor* seqlen_k_; + bool past_present_share_buffer_; + bool is_first_prompt_; +}; + +class InPlaceSoftmaxProgram final : public Program { + public: + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + + private: + int work_group_size_; + int components_; + 
const Tensor* seqlen_k_; + bool is_first_prompt_; +}; + +class VxAttentionScoreProgram final : public Program { + public: + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"n_reps", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_value_; + bool has_present_value_; + int tile_size_; + int n_reps_; + const Tensor* seqlen_k_; + bool past_present_share_buffer_; + bool is_first_prompt_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h new file mode 100644 index 0000000000000..be80ade8b87d0 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
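+// WebgpuAttentionParameters is a thin adapter that flattens both
+// onnxruntime::contrib::AttentionParameters (used by MultiHeadAttention) and
+// GroupQueryAttentionParameters (used by GroupQueryAttention) into the single
+// parameter struct consumed by the shared WebGPU attention path declared
+// below (TransferBSDToBNSH / ApplyAttention).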
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention_common.h" + +#include "contrib_ops/cpu/bert/attention_common.h" +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +struct WebgpuAttentionParameters { + explicit WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_(false), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.kv_sequence_length), + past_sequence_length_(parameters.past_sequence_length), + total_sequence_length_(parameters.total_sequence_length), + max_sequence_length_(parameters.max_sequence_length), + input_hidden_size_(parameters.input_hidden_size), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.v_hidden_size), + v_head_size_(parameters.v_head_size), + num_heads_(parameters.num_heads), + is_unidirectional_(parameters.is_unidirectional), + past_present_share_buffer_(parameters.past_present_share_buffer), + do_rotary_(parameters.do_rotary), + broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), + broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), + mask_filter_value_(parameters.mask_filter_value), + scale_(parameters.scale), + mask_type_(parameters.mask_type), + qkv_format_(parameters.qkv_format) { + } + + explicit WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_(true), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.sequence_length), + past_sequence_length_(parameters.seqlen_past_kv_cache), + total_sequence_length_(parameters.total_sequence_length), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.kv_hidden_size), + v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), + num_heads_(parameters.num_heads), + do_rotary_(parameters.do_rotary), + scale_(parameters.scale), + seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), + seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), + kv_hidden_size_(parameters.kv_hidden_size), + kv_num_heads_(parameters.kv_num_heads), + num_splits_(parameters.num_splits), + rotary_dim_(parameters.rotary_dim), + is_packed_qkv_(parameters.is_packed_qkv), + is_subsequent_prompt_(parameters.is_subsequent_prompt), + is_first_prompt_(parameters.is_first_prompt), + rotary_interleaved_(parameters.rotary_interleaved), + use_smooth_softmax_(parameters.use_smooth_softmax), + softcap_(parameters.softcap), + zeros_count_(parameters.zeros_count), + zero_ptr_(parameters.zero_ptr), + n_reps(parameters.num_heads / parameters.kv_num_heads), + qkv_format_(parameters.qkv_format) { + } + + bool is_gqa_; + int batch_size_ = 0; + int sequence_length_ = 0; + int kv_sequence_length_ = 0; // input sequence length of K or V + int past_sequence_length_ = 0; // sequence length in past state of K or V + int total_sequence_length_ = 0; // total sequence length of K or V + int max_sequence_length_ = 0; // max sequence length from 4D mask + int input_hidden_size_ = 0; // first dimension of weights for input projection + int hidden_size_ = 0; // hidden size of Q or K + int head_size_ = 0; // hidden size per head of Q or K + int v_hidden_size_ = 0; // hidden size of V + int v_head_size_ = 0; // 
hidden size per head of V + int num_heads_ = 0; + int rotary_embedding_ = 0; + bool is_unidirectional_ = false; + bool past_present_share_buffer_ = false; + bool do_rotary_ = false; + bool broadcast_attn_bias_dim_0_ = false; + bool broadcast_attn_bias_dim_1_ = false; + float mask_filter_value_ = -10000.0f; + float scale_ = 0.0f; + bool use_tf32_ = false; + ; + // The following members are in onnxruntime::contrib::GroupQueryAttentionParameters + // and not in onnxruntime::contrib::AttentionParameters + int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor + int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor + int kv_hidden_size_ = 0; + int kv_num_heads_ = 0; + int num_splits_ = 0; // number of splits for splitkv + int rotary_dim_ = 0; // rotary embedding dimension + int local_window_size_ = 0; + bool kv_share_buffer_ = false; + bool is_packed_qkv_ = false; + bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt_ = false; // indicates whether this is first decoding step + bool rotary_interleaved_ = false; + bool use_smooth_softmax_ = false; + float softcap_ = 0.0; + int zeros_count_ = 0; + ; + int* zero_ptr_ = nullptr; + // Computed values + int n_reps = 1; + AttentionMaskType mask_type_ = MASK_NONE; + AttentionQkvFormat qkv_format_ = UNKNOWN; +}; + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor); + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc new file mode 100644 index 0000000000000..a5cae7e7f6747 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
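+// FastGelu for WebGPU: an element-wise kernel that adds an optional bias and
+// then applies the tanh-based GELU approximation provided by the shared
+// FastGeluExpr / TanhImpl WGSL snippets. The tensor is processed four
+// elements per invocation (vec_size = ceil(data_size / 4)).
+// As a reference point (not part of this file), the tanh approximation is
+// commonly written as:
+//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))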
+ +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "contrib_ops/webgpu/bert/fast_gelu.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + FastGelu); + +Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform); + + shader.AdditionalImplementation() << TanhImpl; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " var a = " << x.GetByOffset("global_idx") << ";\n"; + if (Inputs().size() > 1) { + const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); + if (bias_components_ == 1) { + shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n" + " a += x_value_t(" + << bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n"; + } else { + shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } + } + shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr); + + return Status::OK(); +} + +Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + auto* output = context.Output(0, input->Shape()); + + uint32_t data_size = gsl::narrow(output->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const auto vec_size = (data_size + 3) / 4; + uint32_t bias_size = 0; + int bias_components = 1; + + if (bias != nullptr) { + bias_size = gsl::narrow(bias->Shape().Size()); + if (bias_size % 4 == 0) { + bias_components = 4; + bias_size = bias_size / 4; + } + } + + FastGeluProgram program{bias_components}; + program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({vec_size}); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h new file mode 100644 index 0000000000000..fa40d52bf301f --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class FastGeluProgram final : public Program { + public: + FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int bias_components_; +}; + +class FastGelu final : public WebGpuKernel { + public: + FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc new file mode 100644 index 0000000000000..31c8af9b4f922 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "contrib_ops/webgpu/bert/group_query_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::group_query_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + GroupQueryAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .MayInplace(3, 1) + .MayInplace(4, 2) + .InputMemoryType(OrtMemTypeCPUInput, 6), + GroupQueryAttention); + +Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* past_key = context.Input(3); + const Tensor* past_value = context.Input(4); + const Tensor* seqlen_k = context.Input(5); + const Tensor* total_seqlen_tensor = context.Input(6); + const Tensor* cos_cache = context.Input(7); + const Tensor* sin_cache = context.Input(8); + + GroupQueryAttentionParameters params; + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶ms, + num_heads_, + kv_num_heads_, + seqlen_k, + total_seqlen_tensor, + scale_, + softcap_)); + WebgpuAttentionParameters parameters(params); + if (parameters.is_packed_qkv_) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep."); + } + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size_); + output_shape[1] = static_cast(parameters.sequence_length_); + output_shape[2] = static_cast(parameters.hidden_size_); + Tensor* output = context.Output(0, output_shape); + std::vector present_dims{ + parameters.batch_size_, + kv_num_heads_, + parameters.seqlen_present_kv_cache_, + parameters.head_size_}; + std::vector present_kv_shape(present_dims); + Tensor* 
present_key = context.Output(1, present_kv_shape); + Tensor* present_value = context.Output(2, present_kv_shape); + parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); + + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); + } + + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, nullptr, 0, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, nullptr, 0, &V)); + return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h new file mode 100644 index 0000000000000..04969dc778927 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
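+// WebGPU GroupQueryAttention kernel declaration. The constructor reads the op
+// attributes (num_heads, kv_num_heads, scale, softcap, do_rotary,
+// rotary_interleaved, smooth_softmax, local_window_size); the compute path
+// lives in group_query_attention.cc and reuses the shared ApplyAttention
+// helper declared in attention_common.h.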
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class GroupQueryAttention final : public WebGpuKernel { + public: + GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + float softcap_; + bool do_rotary_; // whether or not to use rotary embeddings + bool rotary_interleaved_; + int local_window_size_; + + bool use_smooth_softmax_; + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc new file mode 100644 index 0000000000000..8997e8698d96d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc @@ -0,0 +1,36 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 1, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc new file mode 100644 index 0000000000000..424556c66bd9d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
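+// WebGPU MultiHeadAttention: after input validation this kernel transposes
+// Q (and, when key/value are not already in BNSH layout, K and V) from BSD to
+// BNSH with TransferBSDToBNSH, then runs the shared ApplyAttention pipeline
+// (Q*K^T probs -> in-place softmax -> probs*V).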
+ +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + MultiHeadAttention); + +MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) + : WebGpuKernel(info), AttentionBase(info, false) { + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel"); +} + +Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* bias = context.Input(3); + const Tensor* key_padding_mask = context.Input(4); + const Tensor* attention_bias = context.Input(5); + const Tensor* past_key = context.Input(6); + const Tensor* past_value = context.Input(7); + + if (query->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu"); + } + if (key != nullptr && key->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed KV not implemented for webgpu"); + } + if (key_padding_mask) { + ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu"); + } + + AttentionParameters params; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, + bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶ms, + num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention, + context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); + WebgpuAttentionParameters parameters(params); + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size_); + output_shape[1] = static_cast(parameters.sequence_length_); + output_shape[2] = static_cast(parameters.v_hidden_size_); + Tensor* output = context.Output(0, output_shape); + + // If optional outputs aren't needed, present_key and present_value will be null + std::vector present_dims{ + parameters.batch_size_, + parameters.num_heads_, + parameters.total_sequence_length_, + parameters.head_size_, + }; + TensorShape present_shape(present_dims); + Tensor* present_key = context.Output(1, present_shape); + Tensor* present_value = context.Output(2, present_shape); + + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q)); + + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); + } + + TensorShapeVector k_new_dims({parameters.batch_size_, 
parameters.num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, bias, parameters.hidden_size_, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V)); + + // Compute the attention score and apply the score to V + return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h new file mode 100644 index 0000000000000..d983236422c9e --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention.h" + +#include "contrib_ops/cpu/bert/attention_base.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MultiHeadAttention final : public WebGpuKernel, public AttentionBase { + public: + MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..bc8b7493fc916 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
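+// WebGPU RotaryEmbedding. The shader rotates pairs of elements inside the
+// rotary span of each head and copies the remaining dimensions through
+// unchanged. For a pair (i, j) selected according to `interleaved`, the update
+// computed below is effectively:
+//   out[i] = x[i] * cos - x[j] * sin
+//   out[j] = x[i] * sin + x[j] * cos
+// with cos/sin looked up from cos_cache / sin_cache at (position_id, dim).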
+ +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/rotary_embedding.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform); + const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + // TODO: remove output_indices. + const auto& output_indices = shader.AddIndices("output_indices", false); + const auto interleaved_str = interleaved_ ? "true" : "false"; + shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" + " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n" + " if (global_idx >= size) { return; }\n" + " if (bsnh[3] < half_rotary_emb_dim) {\n" + << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n" + << " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n" + << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n" + << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n" + << " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("i", "re") << "\n" + << " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("j", "im") << "\n" + << " } else { \n" + " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + << " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n" + << " }"; + + return Status::OK(); +} + +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { + scale_ = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads_ = static_cast(info.GetAttrOrDefault("num_heads", 0)); + interleaved_ = (info.GetAttrOrDefault("interleaved", 0) == 1); + is_packed_batching_ = (info.GetAttrOrDefault("is_packed_batching", 0) == 1); +} + +Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto input_shape = input->Shape(); + const auto* position_ids = context.Input(1); + const auto* cos_cache = context.Input(2); + const auto* sin_cache = context.Input(3); + auto* output = context.Output(0, input_shape); + + const auto batch_size = 
gsl::narrow(input->Shape()[0]); + const auto batch_stride = gsl::narrow(input_shape.SizeFromDimension(1)); + const auto sequence_length = gsl::narrow(input_shape[input_shape.NumDimensions() - 2]); + const auto hidden_size = batch_stride / sequence_length; + const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); + const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. + const TensorShape global_shape({batch_size, + sequence_length, + hidden_size / head_size, + head_size - half_rotary_embedding_dim}); + + const auto rank = global_shape.NumDimensions(); + std::vector global_dims(rank); + std::vector global_strides(rank); + for (size_t j = 0; j < rank; ++j) { + global_dims[j] = gsl::narrow(global_shape[j]); + global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); + } + + const auto output_size = gsl::narrow(global_shape.Size()); + RotaryEmbeddingProgram program{interleaved_}; + const auto input_output_strides = + input_shape.NumDimensions() == 3 + ? std::vector({batch_stride, hidden_size, head_size, 1}) + : (input_shape.NumDimensions() == 4 + ? std::vector({batch_stride, head_size, sequence_length * head_size, 1}) + : std::vector({})); + + program + .CacheHint(interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::Rank}, + {position_ids, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}}) + .AddOutput({output, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{scale_}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..0d73b89fb62df --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
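+// Declarations for the WebGPU RotaryEmbedding contrib op: RotaryEmbeddingProgram
+// generates the rotation shader (uniforms: scale, global_shape, global_stride,
+// input_output_stride), and RotaryEmbedding holds the parsed attributes
+// (scale, num_heads, rotary_embedding_dim, interleaved, is_packed_batching).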
+ +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class RotaryEmbeddingProgram final : public Program { + public: + RotaryEmbeddingProgram(bool interleaved) : Program{"RotaryEmbedding"}, interleaved_{interleaved} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32}, + {"global_shape", ProgramUniformVariableDataType::Uint32}, + {"global_stride", ProgramUniformVariableDataType::Uint32}, + {"input_output_stride", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; +}; + +class RotaryEmbedding final : public WebGpuKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + float scale_; + int num_heads_; + int rotary_embedding_dim_; + bool interleaved_; + bool is_packed_batching_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc new file mode 100644 index 0000000000000..fe541f58d34ec --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +static uint32_t GetMaxComponents(int size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("skip", ShaderUsage::UseUniform); + shader.AddInput("gamma", ShaderUsage::UseUniform); + if (hasBeta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + if (hasBias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_input_skip_bias_sum_) { + shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseUniform); + } + + int components = x.NumComponents(); + + std::string bias = (hasBias_) ? " + bias[offset1d + i] " : ""; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; + std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; + std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : ""; + + shader.AdditionalImplementation() + << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? 
"vec2" : "f32")) << ";\n" + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n" + << "var stride = hidden_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * hidden_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = hidden_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let skip_value = skip[offset + i];\n" + << " let input_value = x[offset + i];\n" + << " let value = input_value + skip_value" << bias << ";\n" + << " output[offset + i] = value;\n" + << input_skip_bias_sum + << " let f32_value = f32_val_t(value);\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" + << "};\n"; + + return Status::OK(); +} + +template +Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* x = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* gamma = context.Input(2); + // optional + const Tensor* beta = context.Input(3); + const Tensor* bias = context.Input(4); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + auto* input_skip_bias_sum = context.Output(3, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + const uint32_t hidden_size = gsl::narrow(x_shape[x_shape.NumDimensions() - 1]); + const int components = GetMaxComponents(hidden_size); + const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr; + + SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, is_fp16, simplified}; + program + .CacheHint(simplified, has_input_skip_bias_sum) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize(gsl::narrow(ceil(1.0 * data_size / hidden_size))) + .AddUniformVariables({ + {static_cast(components)}, + }) + .AddUniformVariables({ + {static_cast(hidden_size)}, + }) + .AddUniformVariables({ + 
{static_cast(epsilon_)}, + }); + + if (beta != nullptr) { + program.AddInput({beta, ProgramTensorMetadataDependency::Type, components}); + } + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + if (has_input_skip_bias_sum) { + program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + SkipLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SkipSimplifiedLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h new file mode 100644 index 0000000000000..03de1a4b568b9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class SkipLayerNormProgram final : public Program { + public: + SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool has_input_skip_bias_sum, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} { + epsilon_ = epsilon; + hasBeta_ = hasBeta; + hasBias_ = hasBias; + epsilon_ = epsilon; + hidden_size_ = hidden_size; + has_input_skip_bias_sum_ = has_input_skip_bias_sum; + simplified_ = simplified; + is_fp16_ = is_fp16; + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"components", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool hasBeta_; + bool hasBias_; + float epsilon_; + uint32_t hidden_size_; + bool has_input_skip_bias_sum_; + bool is_fp16_; + bool simplified_; +}; + +template +class SkipLayerNorm final : public WebGpuKernel { + public: + SkipLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + float epsilon_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..9a49adf347a29 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -0,0 +1,461 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
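Stepping back to the SkipLayerNorm program above for a moment: per row it forms x + skip (+ bias), optionally writes that sum to the input_skip_bias_sum output, and then normalizes with the mean and inverse standard deviation, where the simplified variant drops the mean terms. A hedged scalar sketch of the same computation (names illustrative; the kernel vectorizes this and reduces across a workgroup):

#include <cmath>
#include <cstddef>

// Scalar reference for one row; accumulation is done in double here purely for the sketch,
// whereas the shader accumulates in f32 shared memory.
void SkipLayerNormRow(const float* x, const float* skip, const float* bias,  // bias may be null
                      const float* gamma, const float* beta,                 // beta may be null
                      float* out, std::size_t hidden_size, float epsilon, bool simplified) {
  double sum = 0.0, sq_sum = 0.0;
  for (std::size_t i = 0; i < hidden_size; ++i) {
    const float v = x[i] + skip[i] + (bias ? bias[i] : 0.0f);
    out[i] = v;  // this intermediate is also what input_skip_bias_sum stores when requested
    sum += v;
    sq_sum += static_cast<double>(v) * v;
  }
  const double mean = sum / hidden_size;
  const double var = simplified ? sq_sum / hidden_size                 // RMS-style, no mean
                                : sq_sum / hidden_size - mean * mean;
  const float inv_std = static_cast<float>(1.0 / std::sqrt(var + epsilon));
  for (std::size_t i = 0; i < hidden_size; ++i) {
    const float centered = simplified ? out[i] : out[i] - static_cast<float>(mean);
    out[i] = centered * inv_std * gamma[i] + (beta ? beta[i] : 0.0f);
  }
}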
+ +#include + +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { +// Put it to a common place? +uint32_t GetMaxComponents(uint32_t size) { + // we cannot use vec3 type since it has alignment of 16 bytes + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + + return 1; +} + +std::string QuantizedDataType(int components) { + switch (components) { + case 1: + return "array"; + case 2: + return "mat4x2"; + case 4: + return "mat2x4"; + default: + return "array"; + } +} + +constexpr unsigned int kMinMForTileOptimization = 4; +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + MatMulNBits); + +Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + if ((is_intel_ || tile_m_ > 1) && block_size_ == 32) { + const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); + const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. 
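// Hedged size check for this tiled path: assuming the 8x8 workgroup that ComputeInternal below
// configures when block_size == 32 (which also makes components_b_ == 4), tile_size works out to
// 8 * 4 * 8 = 256 weights per tile, blocks_per_tile to 256 / 32 = 8 quantization blocks, and,
// if K is a multiple of 4 so that A is read as vec4, a_length_per_tile to 256 / 4 = 64 elements
// of A staged in shared memory per tile.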
+ const uint32_t a_length_per_tile = tile_size / a.NumComponents(); + const uint32_t blocks_per_tile = tile_size / block_size_; + if (tile_m_ == 1) { + shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" + " if (col < uniforms.input_a_shape[2]) {\n" + << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" + << " } else {\n" + " return input_a_value_t(0);\n" + " }\n" + "}\n" + << "var sub_a: array;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">;\n"; + std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY()); + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n"; + } else { + ORT_ENFORCE(tile_m_ < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); + ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY."); + + shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" + " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" + << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" + << " } else {\n" + " return input_a_value_t(0);\n" + " }\n" + "}\n" + << "var sub_a: array," << tile_m_ << ">;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; + shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" + << " let row = workgroup_id.y * " << tile_m_ << ";\n" + << " let batch = workgroup_id.z;\n"; + } + shader.MainFunctionBody() << " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" + // Loop over shared dimension. + << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" + << " let a_col_start = tile * " << a_length_per_tile << ";\n" + << " // load one tile A data into shared memory.\n" + << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" + << " let a_col = a_col_start + a_offset;\n"; + if (tile_m_ == 1) { + shader.MainFunctionBody() << " sub_a[a_offset] = mm_readA(batch, row, a_col);\n"; + } else { + for (uint32_t i = 0; i < tile_m_; i++) { + shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n"; + } + } + shader.MainFunctionBody() << " }\n" + " workgroupBarrier();\n" + // Each thread processes one block. 
+ " let b_row = col + local_id.y;\n" + << " let block = tile * " << blocks_per_tile << " + local_id.x;\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n" + " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; + } else { + // The default zero point is 8 for unsigned 4-bit quantization. + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; + } + shader.MainFunctionBody() << " var scale = output_element_t(0);\n" + " var b_data = input_b_value_t(0);\n" + << " if (block < n_blocks_per_col) {\n" + << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" + << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" + << " }\n" + << " var word_offset = local_id.x * " << block_size_ / a.NumComponents() << ";\n" + << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; + shader.MainFunctionBody() << " let b_value = b_data"; + if (components_b_ > 1) { + shader.MainFunctionBody() << "[i]"; + } + shader.MainFunctionBody() << ";\n" + " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" + " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" + " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + " let b_dequantized_values = (b_quantized_values - mat2x4("; + for (int i = 0; i < 8; i++) { + shader.MainFunctionBody() << "zero_point"; + if (i < 7) { + shader.MainFunctionBody() << ", "; + } + } + shader.MainFunctionBody() << ")) * scale;\n"; + if (tile_m_ == 1) { + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]), b_dequantized_values[1]);\n"; + break; + case 2: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[1]);\n"; + break; + case 4: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(sub_a[word_offset], b_dequantized_values[0]) + dot(sub_a[word_offset + 1], b_dequantized_values[1]);\n"; + break; + default: + break; + } + } else { + for (uint32_t i = 0; i < tile_m_; i++) { + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], 
sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n"; + break; + case 2: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[1]);\n"; + break; + case 4: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; + break; + default: + break; + } + } + } + shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " workgroupBarrier();\n" + " }\n"; + if (tile_m_ == 1) { + shader.MainFunctionBody() << " if (local_idx < " << WorkgroupSizeY() << ") {\n" + << " var output_value = output_value_t(0);\n" + << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" + << " output_value += inter_results[local_idx][b];\n" + " }\n" + " if (col + local_idx < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n" + << " }\n" + " }\n"; + } else { + shader.MainFunctionBody() << " if (local_id.y < " << tile_m_ << ") {\n" + << " var output_value = output_value_t(0);\n" + << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" + << " output_value += inter_results[local_id.y][local_id.x][b];\n" + " }\n" + " if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n" + << " }\n" + " }\n"; + } + } else { + const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); + const int output_element_number = y.NumComponents() * gsl::narrow(output_number_); + + const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; + std::string offset = "workgroup_idx * " + std::to_string(output_number_); + shader.AdditionalImplementation() << "var workgroup_shared : array;\n"; + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + " let blob_size = uniforms.input_b_shape[2];\n" + " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" + << " var word_offset = block * uniforms.block_size / " << a.NumComponents() << ";\n"; + + // prepare scale and zero point + shader.MainFunctionBody() << " var col_index = col * " << y.NumComponents() << ";\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " var zero_point_byte_count: u32;\n" + " var zero_point_word_index: u32;\n" + " var zero_point_byte_offset: u32;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " var zero_point_bits_offset: u32;\n" + " var zero_point_word: u32;\n"; + for (int c = 0; c < output_element_number; c++) { + 
shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" + " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n" + << " col_index += 1;\n"; + } + } else { + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " col_index += 1;\n"; + } + } + + shader.MainFunctionBody() << " for (var word: u32 = 0; word < blob_size; word += 1) {\n"; + + // prepare b data + shader.MainFunctionBody() << " col_index = col * " << y.NumComponents() << ";\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" + << " col_index += 1;\n"; + } + shader.MainFunctionBody() << " var b_value : u32;\n" + " let b_mask : u32 = 0x0F0F0F0Fu;\n" + " var b_value_lower : vec4;\n" + " var b_value_upper : vec4;\n" + << " var b_quantized_values : " << quantized_data_type << ";\n" + << " var b_dequantized_values : " << quantized_data_type << ";\n"; + + shader.MainFunctionBody() << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; + + // process one word + shader.MainFunctionBody() << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" + << " var a_data: " << quantized_data_type << ";\n" + << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" + << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" + << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" + << " input_offset++;\n" + " } else {\n" + " a_data[j] = input_a_value_t(0);\n" + " }\n" + " }\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " b_value = b" << c << "_data"; + if (components_b_ > 1) { + shader.MainFunctionBody() << "[i]"; + } + shader.MainFunctionBody() << ";\n" + " b_value_lower = unpack4xU8(b_value & b_mask);\n" + " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" + << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + << " b_dequantized_values = "; + if (a.NumComponents() == 1) { + if (has_zero_points_) { + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c 
<< ", " + << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; + } else { + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " + << "(b_quantized_values[1] - zero_point) * scale" << c << "," + << "(b_quantized_values[2] - zero_point) * scale" << c << "," + << "(b_quantized_values[3] - zero_point) * scale" << c << "," + << "(b_quantized_values[4] - zero_point) * scale" << c << "," + << "(b_quantized_values[5] - zero_point) * scale" << c << "," + << "(b_quantized_values[6] - zero_point) * scale" << c << "," + << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; + } + } else { + shader.MainFunctionBody() << "(b_quantized_values - " << quantized_data_type << "("; + for (int i = 0; i < 8; i++) { + if (has_zero_points_) { + shader.MainFunctionBody() << "zero_point" << c; + } else { + shader.MainFunctionBody() << "zero_point"; + } + if (i < 7) { + shader.MainFunctionBody() << ", "; + } + } + shader.MainFunctionBody() << ")) * scale" << c << ";\n"; + } + + shader.MainFunctionBody() << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; + if (y.NumComponents() > 1) { + shader.MainFunctionBody() << "[" << c % y.NumComponents() << "]"; + } + shader.MainFunctionBody() << " += "; + if (a.NumComponents() == 1) { + shader.MainFunctionBody() << "a_data[0] * b_dequantized_values[0] + " + "a_data[1] * b_dequantized_values[1] + " + "a_data[2] * b_dequantized_values[2] + " + "a_data[3] * b_dequantized_values[3] + " + "a_data[4] * b_dequantized_values[4] + " + "a_data[5] * b_dequantized_values[5] + " + "a_data[6] * b_dequantized_values[6] + " + "a_data[7] * b_dequantized_values[7];\n"; + } else if (a.NumComponents() == 2) { + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]) + " + "dot(a_data[2], b_dequantized_values[2]) + " + "dot(a_data[3], b_dequantized_values[3]);\n"; + } else if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]);\n"; + } + } + + shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + << " if (local_id.x < " << output_number_ << ") {\n" + << " var output_value = output_value_t(0);\n" + " var workgroup_shared_offset = local_id.x;\n" + << " let blocks_num = min(" << shared_memory_size << ", n_blocks_per_col);\n" + << " for (var b = 0u; b < blocks_num; b++) {\n" + " output_value += workgroup_shared[workgroup_shared_offset];\n" + << " workgroup_shared_offset += " << output_number_ << ";\n" + << " }\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value") << "\n" + << " }\n"; + } + + return Status::OK(); +} + +Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* b = context.Input(1); + const Tensor* scales = context.Input(2); + const Tensor* zero_points = context.Input(3); + const Tensor* g_idx = context.Input(4); + const Tensor* bias = context.Input(5); + + ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet."); + ORT_ENFORCE(bias == nullptr, "bias as input is not supported yet."); + + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + 
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + auto* y = context.Output(0, helper.OutputShape()); + const uint32_t data_size = gsl::narrow(y->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const uint32_t batch_count = gsl::narrow(helper.OutputOffsets().size()); + const uint32_t M = gsl::narrow(helper.M()); + const uint32_t N = gsl::narrow(helper.N()); + const uint32_t K = gsl::narrow(helper.K()); + const uint32_t block_size = gsl::narrow(block_size_); + constexpr uint32_t nbits = 4; + + const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; + const uint32_t blob_size = (block_size / 8) * nbits; + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + uint32_t components = GetMaxComponents(N); + + const bool is_intel = context.AdapterInfo().vendor == std::string_view{"intel"} && + context.AdapterInfo().architecture == std::string_view{"gen-12lp"}; + const bool has_zero_points = zero_points != nullptr; + + // TODO: Support output_number > 1. Some cases are failed when output_number > 1. + constexpr uint32_t output_number = 1; + const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; + MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, is_intel}; + if (M > kMinMForTileOptimization && block_size == 32) { + components = 1; + constexpr uint32_t workgroup_size = 64; + constexpr uint32_t workgroup_y = 8; + constexpr uint32_t workgroup_x = workgroup_size / workgroup_y; + program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); + program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y, + (M + tile_m - 1) / tile_m, + batch_count); + program.CacheHint("T_M" + std::to_string(tile_m)); + } else if (is_intel && block_size == 32) { + components = 1; + constexpr uint32_t workgroup_size = 128; + const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 + : 1; + const uint32_t workgroup_x = workgroup_size / workgroup_y; + program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); + program.SetDispatchGroupSize(data_size / components / workgroup_y); + program.CacheHint("T_M" + std::to_string(tile_m)); + } else { + program.SetDispatchGroupSize(data_size / components / output_number); + program.CacheHint("O_N" + std::to_string(output_number)); + } + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. 
So here we need to multiply 4.*/)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddUniformVariable({block_size}); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..8a4626083419c --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MatMulNBitsProgram final : public Program { + public: + MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points, bool is_intel) : Program{"MatMulNBits"}, + output_number_{output_number}, + block_size_{block_size}, + tile_m_{tile_m}, + components_b_{components_b}, + has_zero_points_{has_zero_points}, + is_intel_{is_intel} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t output_number_; + uint32_t block_size_; + uint32_t tile_m_; + int components_b_; + bool has_zero_points_; + bool is_intel_; +}; + +class MatMulNBits final : public WebGpuKernel { + public: + MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + N_ = info.GetAttr("N"); + block_size_ = info.GetAttr("block_size"); + int64_t bits = info.GetAttr("bits"); + ORT_ENFORCE(bits == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 8ed1372cd0e62..2e7ed5a16a2f0 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -9,6 +9,24 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); +// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to 
version it +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -18,7 +36,22 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - }; + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/codegen/common/common.cc b/onnxruntime/core/codegen/common/common.cc deleted file mode 100644 index 818b919e99ef2..0000000000000 --- a/onnxruntime/core/codegen/common/common.cc +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/common/common.h" - -#include "core/framework/tensorprotoutils.h" -#include "core/common/inlined_containers.h" -#include "core/graph/graph.h" -#include "core/graph/schema_registry.h" -#include -#include - -namespace onnxruntime { - -NodeKey GetKey(const onnxruntime::Node* node) { - ORT_ENFORCE(nullptr != node); - ORT_ENFORCE(node->OutputDefs().size() > 0); - return node->OutputDefs()[0]->Name(); -} - -NodeKey GetKey(const onnxruntime::Node& node) { - ORT_ENFORCE(node.OutputDefs().size() > 0); - return node.OutputDefs()[0]->Name(); -} - -NodeKey GetKey(const onnxruntime::NodeArg* def) { - // NodeArg's name is unique. 
- ORT_ENFORCE(nullptr != def); - return def->Name(); -} - -bool IsRecurrentNode(const onnxruntime::Node& node) { - auto op_type = node.OpType(); - return (op_type == "LSTM" || op_type == "RNN" || op_type == "GRU" || - op_type == "Scan" || op_type == "Loop"); -} - -bool IsAliasNode(const onnxruntime::Node& node) { - auto op_type = node.OpType(); - if (op_type == "Transpose") { - // Treat Transpose (1,N) -> (N,1) as Alias - const auto shape = node.OutputDefs()[0]->Shape(); - if (shape != nullptr && shape->dim_size() == 2) { - for (int i = 0; i < 2; ++i) { - if (shape->dim(i).has_dim_value() && shape->dim(i).dim_value() == 1) { - return true; - } - } - } - return false; - } - - return (op_type == "Flatten" || op_type == "Identity" || op_type == "Reshape" || - op_type == "Squeeze" || op_type == "Unsqueeze"); -} - -std::string NormalizeCppName(const std::string& name) { - std::string normalized_name = name; - for (char c : {'.', ' ', '+', '-', '*', '/', '\\', '='}) - std::replace(normalized_name.begin(), normalized_name.end(), c, '_'); - return normalized_name; -} - -std::string NormalizeNodeArgName(const NodeArg* def) { - return NormalizeCppName(def->Name()); -} - -bool IsFusedNode(const Node& node) { - if (node.NodeType() == Node::Type::Fused) { - return true; - } - return false; -} - -// A unified API to get Subgraph -const Graph* GetSubgraph(const Node& node) { - if (node.NodeType() == Node::Type::Fused) { - return &(node.GetFunctionBody()->Body()); - } else if (node.OpType() == "Scan") { - return node.GetGraphAttribute("body"); - } - // return nullptr implying no subgraph - return nullptr; -} - -bool HasLoop(const Node& node) { - auto op_type = node.OpType(); - if (op_type == "LSTM" || - op_type == "GRU" || - op_type == "RNN" || - op_type == "Scan") { - return true; - } - return false; -} - -// Return the corresponding input node for the NodeArg of the given node -const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def) { - const auto& input_name = def->Name(); - const onnxruntime::Node* input_node = nullptr; - // search input node set to see if input_name is in their outputs (weights are not from node) - for (auto iter = node.InputNodesBegin(); iter != node.InputNodesEnd(); ++iter) { - const onnxruntime::Node& p = *iter; - bool found = false; - ORT_THROW_IF_ERROR(p.ForEachWithIndex( - p.OutputDefs(), - [&found, &input_name](const onnxruntime::NodeArg& out_def, size_t) { - if (input_name == out_def.Name()) { - found = true; - } - return Status::OK(); - })); - if (found) - input_node = &p; - } - return input_node; -} - -// create capacity from subgraph -std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& graph, - int fused_count, - std::unique_ptr& subgraph) { - auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "Fuse" + std::to_string(fused_count); - meta_def->domain = "Fuse"; - - std::set node_indices(subgraph->nodes.begin(), subgraph->nodes.end()); - - const auto& start_node_index = subgraph->nodes.front(); - const auto& start_node = *graph.GetNode(start_node_index); - const auto& end_node_index = subgraph->nodes.back(); - const auto& end_node = *graph.GetNode(end_node_index); - meta_def->name += start_node.OpType() + std::to_string(start_node_index); - meta_def->name += "_With" + std::to_string(subgraph->nodes.size()) + "Nodes_"; - meta_def->name += end_node.OpType() + std::to_string(end_node_index); - - InlinedHashSet real_output_names; - real_output_names.reserve(graph.GetOutputs().size()); - for (const auto* def : 
graph.GetOutputs()) { - real_output_names.insert(def->Name()); - } - - for (const auto& node_index : subgraph->nodes) { - const auto& node = *graph.GetNode(node_index); - auto process_input_fn = - [&meta_def, &node, &node_indices](const onnxruntime::NodeArg& def, size_t) { - const onnxruntime::Node* input_node = GetInputNode(node, &def); - bool input_from_subgraph = (input_node && node_indices.count(input_node->Index())); - if (!input_from_subgraph) { - // input is from weights or outside of graph - meta_def->inputs.push_back(def.Name()); - } - return Status::OK(); - }; - // handle current graph's inputs - ORT_THROW_IF_ERROR(node.ForEachWithIndex(node.InputDefs(), process_input_fn)); - // nodes' implicit inputs also need to be collected. They need to - // be promoted to being explicit inputs for everything to work. - ORT_THROW_IF_ERROR(node.ForEachWithIndex(node.ImplicitInputDefs(), process_input_fn)); - - // Handle outouts - // two cases are considered as outputs - // 1. Output NodeArg is not used by any Node - // 2. Output NodeArg is used by at least one Node out of this subgraph. - // Note a NodeArg can be used by Nodes in and out of the subgraph at the same time. - // 3. Output NodeArg is one of real outputs of an Ort graph. - - auto InsertOutputToSubgraph = [&meta_def](const NodeArg* def) { - if (std::find(meta_def->outputs.begin(), meta_def->outputs.end(), def->Name()) == - meta_def->outputs.end()) { - meta_def->outputs.push_back(def->Name()); - } - }; - - InlinedHashSet input_names_from_the_output_node; - - for (auto o_iter = node.OutputEdgesBegin(); o_iter != node.OutputEdgesEnd(); ++o_iter) { - const auto& p = *o_iter; - const Node& out_node = p.GetNode(); - - // preprocess for the case 1 - ORT_THROW_IF_ERROR(out_node.ForEachWithIndex( - out_node.InputDefs(), - [&input_names_from_the_output_node](const onnxruntime::NodeArg& in_def, size_t) { - input_names_from_the_output_node.insert(in_def.Name()); - return Status::OK(); - })); - - // handle the case 2 - if (node_indices.count(out_node.Index()) == 0) { - const NodeArg* def = node.OutputDefs()[p.GetSrcArgIndex()]; - InsertOutputToSubgraph(def); - } - } - - // handle case 1 and 3 - ORT_THROW_IF_ERROR(node.ForEachWithIndex( - node.OutputDefs(), - [&](const onnxruntime::NodeArg& def, size_t) { - if (input_names_from_the_output_node.count(def.Name()) == 0 || - real_output_names.count(def.Name()) > 0) { - InsertOutputToSubgraph(&def); - } - return Status::OK(); - })); - } - - // Handle subgraph's initializers - const auto& all_initializers = graph.GetAllInitializedTensors(); - for (const auto& node_index : subgraph->nodes) { - const auto& node = *graph.GetNode(node_index); - // check whether it is an immediate nested subgraph - auto immediate_nested_subgraph = GetSubgraph(node); - // If so, copy the immediate nested subgraph's initializers to meta_def->inputs. - // Note we don't need recursion here, since Ort did recursion for us by handling subgraph early than the current graph. - // Therefore, the all inner nested subgraph's initializers should be already in the immediate nested subgraph's inputs. 
- if (nullptr != immediate_nested_subgraph) { - for (auto& n : immediate_nested_subgraph->Nodes()) { - auto add_input_fn = - [&meta_def, &all_initializers](const onnxruntime::NodeArg& def, size_t) { - auto iter = all_initializers.find(def.Name()); - if (iter != all_initializers.end()) { - meta_def->inputs.push_back(def.Name()); - } - return Status::OK(); - }; - ORT_THROW_IF_ERROR(n.ForEachWithIndex(n.InputDefs(), add_input_fn)); - ORT_THROW_IF_ERROR(n.ForEachWithIndex(n.ImplicitInputDefs(), add_input_fn)); - } - } - } - - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - std::unique_ptr finished_subgraph(subgraph.release()); - finished_subgraph->SetMetaDef(std::move(meta_def)); - return std::make_unique(std::move(finished_subgraph)); -} - -int64_t ShapeRank(const NodeArg* def) { - ORT_ENFORCE_DEBUG(nullptr != def); - return gsl::narrow_cast(def->Shape()->dim_size()); -} - -bool ShapeHasValue(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(nullptr != def); - ORT_ENFORCE_DEBUG(i >= 0); - ORT_ENFORCE_DEBUG(i < def->Shape()->dim_size()); - return utils::HasDimValue(def->Shape()->dim(i)); -} - -bool ShapeHasSymbol(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(nullptr != def); - ORT_ENFORCE_DEBUG(i >= 0); - ORT_ENFORCE_DEBUG(i < def->Shape()->dim_size()); - return utils::HasDimParam(def->Shape()->dim(i)); -} - -int64_t ShapeValue(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(ShapeHasValue(def, i)); - return def->Shape()->dim(i).dim_value(); -} - -const std::string& ShapeSymbol(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(ShapeHasSymbol(def, i)); - return def->Shape()->dim(i).dim_param(); -} - -ONNX_NAMESPACE::TensorProto_DataType TensorProtoDataType(const NodeArg* def) { - ORT_ENFORCE_DEBUG(nullptr != def); - return static_cast(def->TypeAsProto()->tensor_type().elem_type()); -} - -// Convert GraphNodes to internal NodePtrs without check lifetime. -// Please use it only locally when GraphNodes still exist -InlinedVector ConvertGraphNodesToNodePtrs(const ConstGraphNodes& graph_nodes) { - InlinedVector nodes; - for (auto& node : graph_nodes) { - nodes.push_back(&node); - } - return nodes; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/common.h b/onnxruntime/core/codegen/common/common.h deleted file mode 100644 index 81b74daf6f711..0000000000000 --- a/onnxruntime/core/codegen/common/common.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/inlined_containers.h" -#include "core/framework/compute_capability.h" -#include "core/framework/tensor.h" -#include "core/graph/graph_nodes.h" -#include "core/graph/graph_viewer.h" - -#ifndef NDEBUG -#define ORT_ENFORCE_DEBUG(...) ORT_ENFORCE(__VA_ARGS__) -#else -#define ORT_ENFORCE_DEBUG(...) -#endif // !NDEBUG - -// DYN_PROMOTE is a simplified llvm::dyn_cast, which does not need RTTI -// DYN_PROMOTE is faster than dynamic_cast and also has smaller binary size -// Please use DYN_PROMOTE in a critical path. 
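For context on what is being removed here, a hedged sketch of how the promotion macros defined just below were meant to be used; the IR node types in this sketch are hypothetical and exist only for illustration:

enum class IRNodeTypeID { Base, Loop };

struct IRNode {
  explicit IRNode(IRNodeTypeID id) : type_id_(id) {}
  virtual ~IRNode() = default;
  IRNodeTypeID type_id_;
  DYN_PROMOTE_BASE(IRNode, IRNodeTypeID, type_id_)  // adds TypeID() and a trivial IsType()
};

struct LoopNode : IRNode {
  LoopNode() : IRNode(IRNodeTypeID::Loop) {}
  DYN_PROMOTE_DERIVED(IRNode, IRNodeTypeID, Loop)  // IsType() compares the enum tag
};

DYN_PROMOTE(IRNode)  // generates the Promote<ToType>(IRNode*) overloads

// const LoopNode* loop = Promote<LoopNode>(node);  // returns nullptr when the tag does not match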
-#define DYN_PROMOTE(BASE) \ - template \ - inline const ToType* Promote(const BASE* base) { \ - if (ToType::IsType(base)) \ - return static_cast(base); \ - return nullptr; \ - } \ - \ - template \ - inline ToType* Promote(BASE* base) { \ - if (ToType::IsType(base)) \ - return static_cast(base); \ - return nullptr; \ - } \ - \ - template \ - inline ToType* Promote(const std::unique_ptr& base) { \ - if (ToType::IsType(base.get())) \ - return static_cast(base); \ - return nullptr; \ - } \ - \ - template \ - inline ToType* Promote(const std::shared_ptr& base) { \ - if (ToType::IsType(base.get())) \ - return static_cast(base); \ - return nullptr; \ - } - -// DYN_PROMOTE_BASE is a macro inserted in the base class to support DYN_PROMOTE -// TYPE_ID is required for DYN_PROMOTE and TYPE_ID is a enum class -// TYPE_ID_VAR is a corresponding variable name for in the base class -#define DYN_PROMOTE_BASE(BASE, TYPE_ID, TYPE_ID_VAR) \ - inline const TYPE_ID TypeID() const { \ - return TYPE_ID_VAR; \ - } \ - \ - static inline bool IsType(const BASE*) { \ - return true; \ - } - -// DYN_PROMOTE_DERIVED is a macro inserted in a derived class to support DYN_PROMOTE -// TYPE_ID is required for DYN_PROMOTE and TYPE_ID is a enum class -// TYPE_ID_VALUE is corresponding TYPE_ID::value of a derived class. -#define DYN_PROMOTE_DERIVED(BASE, TYPE_ID, TYPE_ID_VALUE) \ - static inline bool IsType(const BASE* base) { \ - ORT_ENFORCE_DEBUG(nullptr != base); \ - return base->TypeID() == TYPE_ID::TYPE_ID_VALUE; \ - } - -// DYNAMIC_PROMOTE is a dynamic_cast needing RTTI -// DYNAMIC_PROMOTE is usually slower than than DYN_PROMOTE. -// Please use DYNAMIC_PROMOTE in a non-critical path. -#define DYNAMIC_PROMOTE(BASE) \ - template \ - inline const X* Promote(const BASE* base) { \ - auto derived = dynamic_cast(base); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } \ - \ - template \ - inline X* Promote(BASE* base) { \ - auto derived = dynamic_cast(base); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } \ - \ - template \ - inline X* Promote(const std::unique_ptr& base) { \ - auto derived = dynamic_cast(base.get()); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } \ - \ - template \ - inline X* Promote(const std::shared_ptr& base) { \ - auto derived = dynamic_cast(base.get()); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } - -namespace onnxruntime { - -// Nodekey is used as a key for maps -using NodeKey = std::string; - -NodeKey GetKey(const onnxruntime::Node* node); -NodeKey GetKey(const onnxruntime::Node& node); -NodeKey GetKey(const onnxruntime::NodeArg* def); - -bool IsRecurrentNode(const onnxruntime::Node& node); - -bool IsAliasNode(const onnxruntime::Node& node); - -// Helper function that creates ComputeCapability for subgraphs -std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& graph, - int fused_count, - std::unique_ptr& subgraph); - -bool IsFusedNode(const Node& node); - -bool HasLoop(const Node& node); - -const Graph* GetSubgraph(const Node& node); - -std::string NormalizeCppName(const std::string& name); - -std::string NormalizeNodeArgName(const NodeArg* def); - -// Return the corresponding input node for the NodeArg of the given node -const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def); - -int64_t ShapeRank(const NodeArg* def); - -bool ShapeHasValue(const NodeArg* def, int i); - -bool ShapeHasSymbol(const NodeArg* def, int i); - -int64_t ShapeValue(const NodeArg* def, int i); - -const std::string& ShapeSymbol(const NodeArg* 
def, int i); - -ONNX_NAMESPACE::TensorProto_DataType TensorProtoDataType(const NodeArg* def); - -// Convert ConstGraphNodes to internal NodePtrs without check lifetime. -// Please use it only locally when GraphNodes still exist -InlinedVector ConvertGraphNodesToNodePtrs(const ConstGraphNodes& graph_nodes); - -enum : int { - Dimension_Unknown = -1, -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/creator.h b/onnxruntime/core/codegen/common/creator.h deleted file mode 100644 index b31a12db4875b..0000000000000 --- a/onnxruntime/core/codegen/common/creator.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/dispatcher.h" - -// TODO rename this file to creator_base -namespace onnxruntime { -namespace codegen { - -// It is a base class for TVM Op IR builder, weight layout builder, TVM scheduler -// CreatorBase is a template class of compiler pass -// for 1) TVM IR builder -// 2) Weight layout transformer -// 3) TVM Scheduler, etc. -// CreatorBase is similor to OpXXCreate in llvm IR builder - -template -class CreatorBase { - public: - CreatorBase(const std::string& name) - : name_(name) {} - - virtual ~CreatorBase() = default; - - virtual RETURN_TYPE Evaluate(INPUT_TYPE, - NODE_TYPE, - CONTEXT_TYPE, - OUTPUT_TYPE) = 0; - - const std::string& Name() const { - return name_; - } - - protected: - std::string name_; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CreatorBase); -}; - -// macro to stringize -#define STRINGIZE_NX(OP) #OP -#define STRINGIZE(OP) STRINGIZE_NX(OP) - -// macro returns class name -#define CREATOR_CLASS(OP, POSTFIX) \ - OP##POSTFIX - -// macro returns class name as string -#define CREATOR_STRING(OP, POSTFIX) \ - STRINGIZE(CREATOR_CLASS(OP, POSTFIX)) - -// macro returns class constructor name -#define CREATOR_CLASS_FUNC(OP, POSTFIX) \ - OP##POSTFIX() - -// macro declares a creator class inheriting the template class CreatorBase -// with corresponding template parameters -#define DECLARE_CREATOR_CLASS(OP, POSTFIX, INPUT, NODE, CONTEXT, OUTPUT, RETURN) \ - class CREATOR_CLASS(OP, POSTFIX) : public onnxruntime::codegen::CreatorBase { \ - public: \ - CREATOR_CLASS_FUNC(OP, POSTFIX) : CreatorBase(CREATOR_STRING(OP, POSTFIX)) {} \ - RETURN Evaluate(INPUT, \ - NODE, \ - CONTEXT, \ - OUTPUT) override; \ - \ - private: \ - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CREATOR_CLASS(OP, POSTFIX)); \ - }; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/dispatcher.h b/onnxruntime/core/codegen/common/dispatcher.h deleted file mode 100644 index 80a854a06977c..0000000000000 --- a/onnxruntime/core/codegen/common/dispatcher.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include -#include -#include - -namespace onnxruntime { -namespace codegen { - -// DispatcherBase is a customized unordered_map -// that provides all codegen-related functionality -// including 1) dispatching a pass -// 2) dump corresponding name -// DispatcherBase may or may not keep ownership, -// depending on the template parameter, CONTENT_TYPE. 
-// Note DispatcherBase has a protected destructor - -template -class DispatcherBase { - public: - DispatcherBase(const std::string& name) - : name_(name) {} - - const std::string& Name() const { - return name_; - } - - bool Contains(const std::string& name) const { - return contents_.count(name) > 0; - } - - void ForEach(std::function - func) { - for (auto& p : contents_) { - func(p.first, p.second); - } - } - - bool Register(const std::string& name, - CONTENT_TYPE op) { - if (!Contains(name)) { - contents_.emplace(name, op); - return true; - } - return false; - } - - CONTENT_TYPE Get(const std::string& key) const { - auto iter = contents_.find(key); - if (iter != contents_.end()) { - return iter->second; - } - return nullptr; - } - - const std::unordered_map GetContents() const { - return contents_; - } - - std::unordered_map GetMutableContents() { - return contents_; - } - - protected: - std::string name_; - std::unordered_map contents_; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DispatcherBase); - ~DispatcherBase() = default; -}; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/dump_array.h b/onnxruntime/core/codegen/common/dump_array.h deleted file mode 100644 index 8e51cd36d0087..0000000000000 --- a/onnxruntime/core/codegen/common/dump_array.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include -#include -#include -#include - -namespace onnxruntime { - -template -void DumpArrayRecursive(const T1* data, int64_t& data_offset, const std::vector& shape, int idx) { - int dim = static_cast(shape.size()); - if (dim == 0) { - std::cout << "[]\n"; - return; - } - - assert(idx < dim); - int sz = shape[idx]; - - std::cout << "["; - if (idx < dim - 1) { - for (auto i = 0; i < sz; ++i) { - DumpArrayRecursive(data, data_offset, shape, idx + 1); - if (i < sz - 1) { - std::cout << ","; - // print multiple newlines after ',' when necessary - for (int j = idx + 1; j < dim; j++) - std::cout << "\n"; - // print leading spaces before "[" when necessary - for (int j = 0; j < idx + 1; ++j) - std::cout << " "; - } - } - } else { - for (auto i = 0; i < sz; ++i) { - if (std::is_same::value || std::is_same::value) - std::cout << std::setw(3) << static_cast(*(data + data_offset)); - else - std::cout << std::setw(12) << std::setprecision(8) << *(data + data_offset); - data_offset++; - if (i < sz - 1) - std::cout << ","; - } - } - std::cout << "]"; -} - -// A helper function to dump multidimensional arrays in a way similar to numpy -template -void DumpArray(const std::string& tag, const T1* data, const std::vector& shape) { - std::cout << tag << "\n"; - int64_t data_offset = 0; - DumpArrayRecursive(data, data_offset, shape, 0); - assert(data_offset == TotalSize(shape)); - std::cout << std::endl; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/handle.h b/onnxruntime/core/codegen/common/handle.h deleted file mode 100644 index 7caad27dcbe01..0000000000000 --- a/onnxruntime/core/codegen/common/handle.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
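Likewise, a hedged usage note for the DumpArray helper from dump_array.h above (the values and field widths shown here are illustrative; the exact spacing comes from the setw/setprecision calls in the header):

// float vals[6] = {1, 2, 3, 4, 5, 6};
// DumpArray("vals", vals, std::vector<int>{2, 3});
// prints the tag on its own line followed by a numpy-style nested layout, roughly:
// vals
// [[  1,  2,  3],
//  [  4,  5,  6]]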
- -#pragma once -#include "core/codegen/common/target_info.h" -#include -#include - -namespace onnxruntime { -namespace codegen { - -using DomainVersionLookupFunc = std::function; - -struct CodeGenHandle { - CodeGenTarget* codegen_target; - DomainVersionLookupFunc domain_version_lookup_func = - // by default, always uses the latest opset implemented - [](const std::string&) { return INT_MAX; }; -}; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/op_macro.h b/onnxruntime/core/codegen/common/op_macro.h deleted file mode 100644 index 04305c4aa47b0..0000000000000 --- a/onnxruntime/core/codegen/common/op_macro.h +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -namespace onnxruntime { - -#define LIST_BINARY_OPS() \ - BINARY_OP(Add) \ - BINARY_OP(Div) \ - BINARY_OP(Mul) \ - BINARY_OP(PRelu) \ - BINARY_OP(Sub) - -#define LIST_BINARY_CMP_OPS() \ - BINARY_CMP_OP(Equal) \ - BINARY_CMP_OP(Greater) \ - BINARY_CMP_OP(Less) - -#define LIST_POOL_OPS() \ - POOL_OP(MaxPool) \ - POOL_OP(AveragePool) \ - POOL_OP(GlobalMaxPool) \ - POOL_OP(GlobalAveragePool) - -#define LIST_REDUCE_OPS() \ - REDUCE_INDEXED_OP(ArgMax) \ - REDUCE_INDEXED_OP(ArgMin) \ - REDUCE_OP(ReduceL1) \ - REDUCE_OP(ReduceL2) \ - REDUCE_OP(ReduceLogSum) \ - REDUCE_OP(ReduceLogSumExp) \ - REDUCE_OP(ReduceMax) \ - REDUCE_OP(ReduceMean) \ - REDUCE_OP(ReduceMin) \ - REDUCE_OP(ReduceProd) \ - REDUCE_OP(ReduceSum) \ - REDUCE_OP(ReduceSumSquare) - -#define LIST_UNARY_OPS() \ - UNARY_OP(Abs) \ - UNARY_OP(Affine) \ - UNARY_OP(Ceil) \ - UNARY_OP(Elu) \ - UNARY_OP(Exp) \ - UNARY_OP(Floor) \ - UNARY_OP(HardSigmoid) \ - UNARY_OP(LeakyRelu) \ - UNARY_OP(Log) \ - UNARY_OP(Neg) \ - UNARY_OP(ParametricSoftplus) \ - UNARY_OP(Reciprocal) \ - UNARY_OP(Relu) \ - UNARY_OP(ScaledTanh) \ - UNARY_OP(Selu) \ - UNARY_OP(Sigmoid) \ - UNARY_OP(Softplus) \ - UNARY_OP(Softsign) \ - UNARY_OP(Sqrt) \ - UNARY_OP(Tanh) \ - UNARY_OP(ThresholdedRelu) - -#define LIST_VARIADIC_OPS() \ - VARIADIC_OP(Max) \ - VARIADIC_OP(Min) \ - VARIADIC_OP(Sum) - -#define LIST_ALL_GENERIC_OPS() \ - LIST_BINARY_OPS() \ - LIST_BINARY_CMP_OPS() \ - LIST_REDUCE_OPS() \ - LIST_POOL_OPS() \ - LIST_UNARY_OPS() \ - LIST_VARIADIC_OPS() \ - ADD_OP_ITEM(Cast) \ - ADD_OP_ITEM(Clip) \ - ADD_OP_ITEM(Concat) \ - ADD_OP_ITEM(Conv) \ - ADD_OP_ITEM(Crop) \ - ADD_OP_ITEM(Dropout) \ - ADD_OP_ITEM(Expand) \ - ADD_OP_ITEM(Flatten) \ - ADD_OP_ITEM(Gather) \ - ADD_OP_ITEM(GatherElements) \ - ADD_OP_ITEM(Gemm) \ - ADD_OP_ITEM(Identity) \ - ADD_OP_ITEM(LogSoftmax) \ - ADD_OP_ITEM(LSTM) \ - ADD_OP_ITEM(MatMul) \ - ADD_OP_ITEM(MatMulInteger) \ - ADD_OP_ITEM(Pad) \ - ADD_OP_ITEM(Reshape) \ - ADD_OP_ITEM(Shape) \ - ADD_OP_ITEM(Slice) \ - ADD_OP_ITEM(Softmax) \ - ADD_OP_ITEM(Split) \ - ADD_OP_ITEM(Squeeze) \ - ADD_OP_ITEM(Transpose) \ - ADD_OP_ITEM(Unsqueeze) \ - ADD_OP_ITEM(Where) - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/profile.h b/onnxruntime/core/codegen/common/profile.h deleted file mode 100644 index 31c9e764320d0..0000000000000 --- a/onnxruntime/core/codegen/common/profile.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
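The LIST_*_OPS macros in the removed op_macro.h above follow the X-macro pattern: each consumer defines BINARY_OP / UNARY_OP / ADD_OP_ITEM to whatever it needs and then expands the shared list, so the operator set is declared once. A small self-contained illustration of that expansion (the op names come from the list above; everything else here is hypothetical):

// X-macro expansion sketch: the list macro stays fixed, the per-item macro varies.
#include <iostream>
#include <string>
#include <vector>

#define LIST_DEMO_BINARY_OPS() \
  DEMO_BINARY_OP(Add)          \
  DEMO_BINARY_OP(Mul)          \
  DEMO_BINARY_OP(Sub)

// Expansion 1: declare one handler per op.
#define DEMO_BINARY_OP(OP) void Handle##OP();
LIST_DEMO_BINARY_OPS()
#undef DEMO_BINARY_OP

// Expansion 2: build a table of op names from the same list.
#define DEMO_BINARY_OP(OP) #OP,
static const std::vector<std::string> kDemoBinaryOps = {LIST_DEMO_BINARY_OPS()};
#undef DEMO_BINARY_OP

// Expansion 3: define the handlers, again from the same list.
#define DEMO_BINARY_OP(OP) \
  void Handle##OP() { std::cout << "handling " #OP "\n"; }
LIST_DEMO_BINARY_OPS()
#undef DEMO_BINARY_OP

int main() {
  for (const auto& op : kDemoBinaryOps) std::cout << op << "\n";
  HandleAdd();
  return 0;
}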
- -#pragma once - -// uncomment this line or use -DCODEGEN_ENABLE_PROFILER in compiler options to enable profiler events in codegen -// #define CODEGEN_ENABLE_PROFILER - -#ifdef CODEGEN_ENABLE_PROFILER -#include "core/common/profiler.h" - -namespace onnxruntime { - -class ProfilerEvent { - public: - ProfilerEvent(const std::string& name) : name_(name) { - ts_ = profiling::Profiler::Instance().StartTime(); - } - - ~ProfilerEvent() { - profiling::Profiler::Instance().EndTimeAndRecordEvent(profiling::EventCategory::NODE_EVENT, name_, ts_); - } - - private: - TimePoint ts_; - const std::string name_; -}; - -} // namespace onnxruntime - -#define CODEGEN_PROFILER_EVENT(name) onnxruntime::ProfilerEvent profiler_event(name) - -#else - -#define CODEGEN_PROFILER_EVENT(name) - -#endif diff --git a/onnxruntime/core/codegen/common/registry.h b/onnxruntime/core/codegen/common/registry.h deleted file mode 100644 index c1642e76e2120..0000000000000 --- a/onnxruntime/core/codegen/common/registry.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include -#include -#include - -namespace onnxruntime { -namespace codegen { - -// RegistryBase is a customized unordered_map -// that keep ownership of passes, -// including 1) IR builder passes -// 2) Weight layout transformer passes -// 3) Scheduler passses, etc. - -template -class RegistryBase { - public: - RegistryBase() = default; - - virtual ~RegistryBase() = default; - - bool Contains(const std::string& name) const { - return contents_.count(name) > 0; - } - - CONTENT_TYPE* Get(const std::string& name) const { - if (contents_.find(name) != contents_.end()) - return contents_.at(name).get(); - return nullptr; - } - - CONTENT_TYPE* RegisterOrGet( - const std::string& name, - std::unique_ptr&& ptr) { - if (!Contains(name)) - contents_.emplace(name, std::move(ptr)); - return Get(name); - } - - CONTENT_TYPE* RegisterOrGet( - std::unique_ptr&& ptr) { - return RegisterOrGet(ptr->Name(), std::move(ptr)); - } - - bool Register( - const std::string& name, - std::unique_ptr&& ptr) { - if (!Contains(name)) { - contents_.emplace(name, std::move(ptr)); - return true; - } - return false; - } - - bool Register( - std::unique_ptr&& ptr) { - return Register(ptr->Name(), std::move(ptr)); - } - - protected: - std::unordered_map> contents_; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RegistryBase); -}; - -// Put common Registry Management utilities if these is any - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/settings.cc b/onnxruntime/core/codegen/common/settings.cc deleted file mode 100644 index 529cb654f922c..0000000000000 --- a/onnxruntime/core/codegen/common/settings.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
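The removed ProfilerEvent above is an RAII guard: constructing it captures a start timestamp and its destructor records the finished event, so a single CODEGEN_PROFILER_EVENT line times an entire scope. A minimal sketch of the same idea using only std::chrono; this is not the onnxruntime profiler API.

// RAII scope timer sketch: start time in the constructor, report in the destructor.
#include <chrono>
#include <iostream>
#include <string>
#include <thread>

class ScopeTimer {
 public:
  explicit ScopeTimer(std::string name)
      : name_(std::move(name)), start_(std::chrono::steady_clock::now()) {}
  ~ScopeTimer() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_)
                  .count();
    std::cout << name_ << " took " << us << " us\n";
  }

 private:
  std::string name_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  ScopeTimer t("demo_scope");  // destroyed (and reported) when main returns
  std::this_thread::sleep_for(std::chrono::milliseconds(1));
  return 0;
}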
- -#include "core/codegen/common/settings.h" - -#include "core/common/logging/logging.h" -#include -#include - -namespace onnxruntime { -namespace codegen { - -CodeGenSettings& CodeGenSettings::Instance() { - static CodeGenSettings settings; - return settings; -} - -CodeGenSettings::CodeGenSettings() {} - -void CodeGenSettings::InsertOptions(const std::map& options) { - for (const auto& option : options) { - const auto& key = option.first; - const auto& value = option.second; - - auto iter = options_.find(key); - // found existing ones - if (iter != options_.end()) { - if (iter->second != value) { - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "CodeGenSettings: option" - << key << " is overridded from: " - << iter->second << " to: " << value; - iter->second = value; - } - } else { - options_.insert(std::make_pair(key, value)); - } - } -} - -void CodeGenSettings::DumpOptions() const { - std::ostringstream stream; - stream << "CodeGenSettings: dump all options" << std::endl; - for (const auto& option : options_) { - stream << " " << option.first << " = " << option.second << std::endl; - } - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << stream.str(); -} - -std::string CodeGenSettings::GetOptionValue(const std::string& key) const { - const auto& iter = options_.find(key); - if (iter == options_.end()) { - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "CodeGenSettings::GetOptionValue: unrecognized option" << key; - return ""; - } - return iter->second; -} - -bool CodeGenSettings::HasOption(const std::string& key) const { - return options_.count(key) > 0; -} - -bool CodeGenSettings::OptionMatches(const std::string& key, const std::string& value) const { - if (!HasOption(key)) - return false; - -#ifdef _WIN32 - return 0 == _stricmp(options_.at(key).c_str(), value.c_str()); -#else - return 0 == strcasecmp(options_.at(key).c_str(), value.c_str()); -#endif -} - -void CodeGenSettings::Clear() { - options_.clear(); -} - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/settings.h b/onnxruntime/core/codegen/common/settings.h deleted file mode 100644 index e327b0e207cc2..0000000000000 --- a/onnxruntime/core/codegen/common/settings.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace codegen { - -// use log level warning as default to make sure logs are outputted -#define CODEGEN_SETTINGS_LOG_LEVEL WARNING - -// This stores codegen settings to control dumps, execution preference, etc. 
-// CodeGenSettings could come from command line options or environment variables -// Or could come from a static variables in source code -class CodeGenSettings { - public: - // generic built-in options - constexpr static const char* kDumpAllOptions = "dump_all_options"; - constexpr static const char* kCodeGenDumpModule = "codegen_dump_module"; // dump tvm module - constexpr static const char* kCodeGenDumpLower = "codegen_dump_lower"; // dump lowered func - constexpr static const char* kCodeGenDumpSchedule = "codegen_dump_schedule"; // dump scheduler - - void InsertOptions(const std::map& options); - void DumpOptions() const; - std::string GetOptionValue(const std::string& key) const; - bool HasOption(const std::string& key) const; - bool OptionMatches(const std::string& key, const std::string& value) const; - void Clear(); - static CodeGenSettings& Instance(); - - private: - CodeGenSettings(); - - std::map options_; -}; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/target_info.h b/onnxruntime/core/codegen/common/target_info.h deleted file mode 100644 index da063545f0a1e..0000000000000 --- a/onnxruntime/core/codegen/common/target_info.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { - -// CodeGenTarget holds meta info for backend code generation -// and will be lowered to a target of corresponding backend -// code generation, e.g. TVM's Target. -class CodeGenTarget { - public: - CodeGenTarget() {} - CodeGenTarget(const std::string& target_name) - : target_name_(target_name) {} - - virtual int NaturalVectorWidth(int /*bits*/) const { - return 1; - } - - const std::string& GetTargetName() const { - return target_name_; - } - - virtual ~CodeGenTarget() = default; - - private: - std::string target_name_{"unknown"}; // default name is unknown -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/utils.cc b/onnxruntime/core/codegen/common/utils.cc deleted file mode 100644 index f4140a411bddf..0000000000000 --- a/onnxruntime/core/codegen/common/utils.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/common/utils.h" -#include "core/common/cpuid_info.h" -#include "core/common/safeint.h" - -#include -#include - -namespace onnxruntime { - -std::unique_ptr GetEnv(const char* var) { - char* val = nullptr; -#if _MSC_VER - size_t len; - - if (_dupenv_s(&val, &len, var)) { - // Something went wrong, just return nullptr. - return nullptr; - } -#else - val = getenv(var); -#endif // _MSC_VER - - if (val == nullptr) { - return nullptr; - } - - // On windows, we will have to explicitly free val. Instead of returning val - // to its caller and make distinguish between windows and linux, we return - // a unique_ptr, and it will be destroyed automatically after the caller - // completes. 
- size_t len_val = strnlen(val, onnxruntime::kMaxStrLen) + 1; - auto p = std::make_unique(len_val); - // use explicit loop to get ride of VC's warning on unsafe copy - for (size_t i = 0; i < len_val; ++i) { - p[i] = val[i]; - } - return p; -} - -bool IsEnvVarDefined(const char* var) { - auto val = GetEnv(var); - return val != nullptr; -} - -int64_t TotalSize(const std::vector& shape) { - SafeInt total = 1; - for (auto s : shape) { - total *= s; - } - return total; -} - -// Return the strides for the input shape, i.e. the number of -// elements contained by a single element of current dimension. -// For example, for shape[3][4][5][6], strides will be -// [4*5*6, 5*6, 6, 1], i.e. [120, 30, 6, 1] -void GetStrides(const int64_t* shape, int ndim, std::vector& strides) { - strides.resize(ndim); - strides[ndim - 1] = 1; - for (int64_t i = ndim - 2; i >= 0; i--) { - strides[i] = strides[i + 1] * shape[i + 1]; - } -} - -// Common utils to get target option -TargetFeature GetTargetInfo(const codegen::CodeGenSettings& settings) { - TargetFeature feature; - - std::string target_str = ""; - - bool isAVX = false; - bool isAVX2 = false; - bool isAVX512 = false; - if (target_str == "avx") { - isAVX = true; - } else if (target_str == "avx2") { - isAVX = true; - isAVX2 = true; - } else if (target_str == "avx512") { - isAVX = true; - isAVX2 = true; - isAVX512 = true; - } else { - isAVX = CPUIDInfo::GetCPUIDInfo().HasAVX(); - isAVX2 = CPUIDInfo::GetCPUIDInfo().HasAVX2(); - isAVX512 = CPUIDInfo::GetCPUIDInfo().HasAVX512Skylake(); - } - - feature.hasAVX = isAVX; - feature.hasAVX2 = isAVX2; - feature.hasAVX512 = isAVX512; - - return feature; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/utils.h b/onnxruntime/core/codegen/common/utils.h deleted file mode 100644 index ef06b5b72dc2c..0000000000000 --- a/onnxruntime/core/codegen/common/utils.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include -#include -#include - -namespace onnxruntime { - -// Holding utility functions that are not tied to TVM and ORT - -std::unique_ptr GetEnv(const char* var); - -// Check if an environment variable is set -bool IsEnvVarDefined(const char* var); - -int64_t TotalSize(const std::vector& shape); - -void GetStrides(const int64_t* shape, int ndim, std::vector& strides); - -struct TargetFeature { - bool hasAVX; - bool hasAVX2; - bool hasAVX512; -}; - -TargetFeature GetTargetInfo(const codegen::CodeGenSettings& setttings); - -// GCD (Greatest Common Divisor) -template -T GCD(T a, T b) { - ORT_ENFORCE(a >= 0); - ORT_ENFORCE(b >= 0); - if (a < b) std::swap(a, b); - if (b == 0) return a; - while (a % b != 0) { - a = a % b; - std::swap(a, b); - } - return b; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/common.h b/onnxruntime/core/codegen/mti/common.h deleted file mode 100644 index d71e740b9284a..0000000000000 --- a/onnxruntime/core/codegen/mti/common.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
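The stride rule documented for the removed GetStrides above (shape [3][4][5][6] gives strides [120, 30, 6, 1]) can be checked with a few lines of standalone code. ComputeStrides here is an illustrative restatement of that row-major rule, not the deleted helper.

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides: strides[i] is the number of elements spanned by one step along dim i.
static std::vector<int64_t> ComputeStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size());
  int64_t running = 1;
  for (std::size_t i = shape.size(); i-- > 0;) {
    strides[i] = running;
    running *= shape[i];
  }
  return strides;
}

int main() {
  // The example from the removed comment: shape [3,4,5,6] -> strides [120,30,6,1].
  assert((ComputeStrides({3, 4, 5, 6}) == std::vector<int64_t>{120, 30, 6, 1}));
  return 0;
}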
- -#pragma once - -#include -#include - -#define MTI_ASSERT(condition) \ - if (!(condition)) { \ - std::string error_msg = "Not satisfied: " #condition \ - ": line " + \ - std::to_string(__LINE__) + \ - " in file " + std::string(__FILE__) + "\n"; \ - throw std::runtime_error(error_msg); \ - } diff --git a/onnxruntime/core/codegen/mti/debug/tvm_print.cc b/onnxruntime/core/codegen/mti/debug/tvm_print.cc deleted file mode 100644 index 0491636032b47..0000000000000 --- a/onnxruntime/core/codegen/mti/debug/tvm_print.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/debug/tvm_print.h" - -#include "core/codegen/common/utils.h" -#include "core/codegen/common/dump_array.h" -#include "core/codegen/mti/common.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -TVM_REGISTER_GLOBAL("tvm.contrib.onnxruntime.print") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* /*ret*/) { - DLTensor* X = args[0]; - DLTensor* Y = args[1]; - - DLDataType dtype = X->dtype; - std::vector shape; - int64_t total_size = 1; - for (int i = 0; i < X->ndim; ++i) { - shape.push_back(X->shape[i]); - total_size *= X->shape[i]; - } - - // pass X to Y - memcpy(static_cast(Y->data) + Y->byte_offset, - static_cast(X->data) + X->byte_offset, - total_size * dtype.bits / 8); - - if (tvm::runtime::TypeMatch(dtype, kDLFloat, 32)) { - float* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("float tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLInt, 8)) { - int8_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("int8 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLInt, 16)) { - int16_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("int16 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLInt, 32)) { - int32_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("int32 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLUInt, 8)) { - uint8_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("uint8 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLUInt, 16)) { - uint16_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("uint16 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLUInt, 32)) { - uint32_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("uint32 tensor:", data, shape); - } else { - MTI_ASSERT(0 && "not implemented!"); - } - }); - -tvm::Array -PrintTVMTensorExtern(const tvm::Tensor& X, - const std::string& name) { - return topi::detail::make_extern( - {X->shape}, - {X->dtype}, - {X}, - [&](tvm::Array ins, tvm::Array outs) { - return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.print"), - topi::detail::pack_buffer(ins[0]), - topi::detail::pack_buffer(outs[0])}); - }, - name + "_print", "", {}); -} - -tvm::Tensor PrintImmutable(const tvm::Tensor& X) { - auto outputs = PrintTVMTensorExtern(X, X->op->name + "_print"); - return outputs[0]; -} - -void Print(tvm::Tensor& X) { - X = PrintImmutable(X); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/debug/tvm_print.h b/onnxruntime/core/codegen/mti/debug/tvm_print.h deleted file mode 100644 index 91a334785a2a4..0000000000000 --- a/onnxruntime/core/codegen/mti/debug/tvm_print.h +++ 
/dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array PrintTVMTensorExtern( - const tvm::Tensor& X, - const std::string& name = "PrintTVM2DTensorExtern"); - -tvm::Tensor PrintImmutable(const tvm::Tensor& X); - -void Print(tvm::Tensor& X); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/binary_ops.cc b/onnxruntime/core/codegen/mti/math/binary_ops.cc deleted file mode 100644 index f3048799458f4..0000000000000 --- a/onnxruntime/core/codegen/mti/math/binary_ops.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/binary_ops.h" - -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/cast_ops.h" -#include - -// Using namespace topi for override operator +-*/ -using namespace topi; - -namespace onnxruntime { -namespace tvm_codegen { - -#define TVM_BINARY_OP1(op, expr) \ - tvm::Tensor op(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { \ - return Rename(expr, name); \ - } \ - tvm::Tensor op(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { \ - return Rename(expr, name); \ - } - -#define TVM_BINARY_OP(op, expr) \ - TVM_BINARY_OP1(op, expr) \ - tvm::Tensor op(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { \ - return Rename(expr, name); \ - } - -TVM_BINARY_OP(Add, lhs + rhs); -TVM_BINARY_OP(Div, lhs / rhs); -TVM_BINARY_OP(Max, maximum(lhs, rhs)); -TVM_BINARY_OP(Min, minimum(lhs, rhs)); -TVM_BINARY_OP(Mul, lhs* rhs); -TVM_BINARY_OP1(PRelu, Relu(lhs) - rhs * Relu(0 - lhs)); -TVM_BINARY_OP(Sub, lhs - rhs); - -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::equal(lhs, rhs, name); -} -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { - return topi::equal(lhs, rhs, name); -} -tvm::Tensor Equal(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::equal(lhs, rhs, name); -} - -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::greater(lhs, rhs, name); -} -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { - return topi::greater(lhs, rhs, name); -} -tvm::Tensor Greater(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::greater(lhs, rhs, name); -} - -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::less(lhs, rhs, name); -} -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { - return topi::less(lhs, rhs, name); -} -tvm::Tensor Less(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::less(lhs, rhs, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/binary_ops.h b/onnxruntime/core/codegen/mti/math/binary_ops.h deleted file mode 100644 index dd51ce5e7917d..0000000000000 --- a/onnxruntime/core/codegen/mti/math/binary_ops.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Add(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "add"); -tvm::Tensor Add(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "add"); -tvm::Tensor Add(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "add"); -tvm::Tensor Div(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "div"); -tvm::Tensor Div(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "div"); -tvm::Tensor Div(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "div"); -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "equal"); -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "equal"); -tvm::Tensor Equal(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "equal"); -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "greater"); -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "greater"); -tvm::Tensor Greater(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "greater"); -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "less"); -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "less"); -tvm::Tensor Less(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "less"); -tvm::Tensor Max(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "max"); -tvm::Tensor Max(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "max"); -tvm::Tensor Max(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "max"); -tvm::Tensor Min(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "min"); -tvm::Tensor Min(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "min"); -tvm::Tensor Min(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "min"); -tvm::Tensor Mul(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "mul"); -tvm::Tensor Mul(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "mul"); -tvm::Tensor Mul(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "mul"); -tvm::Tensor PRelu(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "prelu"); -tvm::Tensor PRelu(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "prelu"); -tvm::Tensor Sub(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "sub"); -tvm::Tensor Sub(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "sub"); -tvm::Tensor Sub(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "sub"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/gemm.cc b/onnxruntime/core/codegen/mti/math/gemm.cc deleted file mode 100644 index 7a79513ccaa97..0000000000000 --- a/onnxruntime/core/codegen/mti/math/gemm.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/mti/math/gemm.h" - -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -// Using namespace topi for override operator +-*/ -using namespace topi; - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gemm(const tvm::Tensor& A, const tvm::Tensor& B, const tvm::Tensor& C, - bool trans_A, bool trans_B, float alpha, float beta, - const std::string& name) { - auto A_dot_B = MatMul2D(A, B, trans_A, trans_B, name + "_matmul2d"); - tvm::Expr alphaExpr = tvm::make_const(A->dtype, alpha); - if (beta != 0) { - tvm::Expr betaExpr = tvm::make_const(A->dtype, beta); - return Rename(alphaExpr * A_dot_B + (betaExpr * C), name); - } else { - return Rename(alphaExpr * A_dot_B, name); - } -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/gemm.h b/onnxruntime/core/codegen/mti/math/gemm.h deleted file mode 100644 index 3bb205c13fdc9..0000000000000 --- a/onnxruntime/core/codegen/mti/math/gemm.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gemm(const tvm::Tensor& p_A, const tvm::Tensor& p_B, const tvm::Tensor& p_C, - bool trans_A, bool trans_B, float alpha, float beta, - const std::string& name = "gemm"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/logsoftmax.cc b/onnxruntime/core/codegen/mti/math/logsoftmax.cc deleted file mode 100644 index cd8c2edae6959..0000000000000 --- a/onnxruntime/core/codegen/mti/math/logsoftmax.cc +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/logsoftmax.h" - -#include "core/codegen/mti/tensor/reshape_ops.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor LogSoftmax(const tvm::Tensor& input, int64_t axis, const std::string& name) { - tvm::Tensor flatten_t = Flatten(input, axis, "logsoftmax_flatten"); - return Reshape(topi::nn::log_softmax(flatten_t, name), input->shape, "logsoftmax_reshape"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/logsoftmax.h b/onnxruntime/core/codegen/mti/math/logsoftmax.h deleted file mode 100644 index 606a32806434b..0000000000000 --- a/onnxruntime/core/codegen/mti/math/logsoftmax.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor LogSoftmax(const tvm::Tensor& input, int64_t axis, const std::string& name = "logsoftmax"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/matmul_ops.cc b/onnxruntime/core/codegen/mti/math/matmul_ops.cc deleted file mode 100644 index 6ecf2f69a9c25..0000000000000 --- a/onnxruntime/core/codegen/mti/math/matmul_ops.cc +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
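The removed Gemm above lowers to alpha * MatMul2D(A, B), adding beta * C only when beta is nonzero. As a plain-C++ reference for that contract (row-major, no transposes, purely illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

// Y = alpha * (A x B) + beta * C, all row-major: A is MxK, B is KxN, C and Y are MxN.
static std::vector<float> ReferenceGemm(const std::vector<float>& A, const std::vector<float>& B,
                                        const std::vector<float>& C, std::size_t M, std::size_t K,
                                        std::size_t N, float alpha, float beta) {
  std::vector<float> Y(M * N, 0.0f);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      // Skip the C term entirely when beta == 0, as the removed code does.
      Y[m * N + n] = alpha * acc + (beta != 0.0f ? beta * C[m * N + n] : 0.0f);
    }
  }
  return Y;
}

int main() {
  // 2x2 check: 2 * ([1 2; 3 4] x [5 6; 7 8]) + 3 * [1 1; 1 1] = [41 47; 89 103].
  std::vector<float> A{1, 2, 3, 4}, B{5, 6, 7, 8}, C{1, 1, 1, 1};
  assert((ReferenceGemm(A, B, C, 2, 2, 2, 2.0f, 3.0f) == std::vector<float>{41, 47, 89, 103}));
  return 0;
}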
- -#include "core/codegen/mti/math/matmul_ops.h" - -#include "core/codegen/mti/common.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor MatMul2D(const tvm::Tensor& A, const tvm::Tensor& B, bool trans_a, bool trans_b, const std::string& name) { - return topi::matmul(A, B, trans_a, trans_b, name); -} - -/* - * Generic Matrix Multiplication - * - * If both arguments are 2-D, they are multiplied like conventional matrices. - * - * If either argument is N-D and N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly. - * - * If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. - * After matrix multiplication the prepended 1 is removed. - * - * If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. - * After matrix multiplication the appended 1 is removed. - */ -tvm::Tensor MatMul(const tvm::Tensor& A, const tvm::Tensor& B, const std::string& name) { - int64_t a_rank = static_cast(A->shape.size()); - int64_t b_rank = static_cast(B->shape.size()); - const auto& A_shape = A->shape; - const auto& B_shape = B->shape; - if (a_rank == 2 && b_rank == 2) { - // 2-D X 2-D - return MatMul2D(A, B); - } else if (a_rank == 1 && b_rank == 1) { - // 1-D X 1-D - auto k = tvm::reduce_axis(tvm::Range(0, A_shape[0]), "k"); - - return tvm::compute( - {}, - [&](const tvm::Array& /*indices*/) { - return tvm::sum(A[k] * B[k], {k}); - }, - name); - } else if (a_rank == 1) { - // 1-D X n-D - auto k = tvm::reduce_axis(tvm::Range(0, A_shape[0]), "k"); - - auto l = [&](const tvm::Array& indices) { - auto ndims = indices.size(); - MTI_ASSERT(ndims >= 1); - tvm::Array b_indices; - for (size_t bi = 0; bi < ndims - 1; ++bi) { - b_indices.push_back(indices[bi]); - } - b_indices.push_back(k); - b_indices.push_back(indices[ndims - 1]); - return tvm::sum(A({k}) * B(b_indices), {k}); - }; - return tvm::compute(ConcatShapes(SliceShapeToDimension(B_shape, -2), SliceShapeFromDimension(B_shape, -1)), l, name); - } else if (b_rank == 1) { - // n-D X 1-D - auto k = tvm::reduce_axis(tvm::Range(0, B_shape[0]), "k"); - - auto l = [&](const tvm::Array& indices) { - tvm::Array a_indices(indices.begin(), indices.end()); - a_indices.push_back(k); - return tvm::sum(A(a_indices) * B({k}), {k}); - }; - return tvm::compute(SliceShapeToDimension(A->shape, -1), l, name); - } else { - // n-D X m-D - MTI_ASSERT(a_rank >= 2 && b_rank >= 2); - auto k = tvm::reduce_axis(tvm::Range(0, A_shape[a_rank - 1]), "k"); - - auto l = [&](const tvm::Array& indices) { - auto ndims = static_cast(indices.size()); - MTI_ASSERT(ndims > 2); - tvm::Array a_indices, b_indices; - - // handle broadcasting - int i = 0, a_idx = 0, b_idx = 0; - bool a_greater = a_rank > b_rank; - for (; i < std::abs(a_rank - b_rank); ++i) { - if (a_greater) { - a_indices.push_back(indices[i]); - a_idx++; - } else { - b_indices.push_back(indices[i]); - b_idx++; - } - } - for (; i < ndims - 2; ++i, ++a_idx, ++b_idx) { - auto tp = indices[i].type(); - if (IsOne(A_shape, a_idx)) { - a_indices.push_back(tvm::make_zero(tp)); - b_indices.push_back(indices[i]); - } else if (IsOne(B_shape, b_idx)) { - b_indices.push_back(tvm::make_zero(tp)); - a_indices.push_back(indices[i]); - } else { - a_indices.push_back(indices[i]); - b_indices.push_back(indices[i]); - } - } - - MTI_ASSERT(a_idx == a_rank - 2 && b_idx == b_rank - 2); - a_indices.push_back(indices[ndims - 2]); - a_indices.push_back(k); - - 
b_indices.push_back(k); - b_indices.push_back(indices[ndims - 1]); - - return tvm::sum(A(a_indices) * B(b_indices), {k}); - }; - - return tvm::compute(ComputeMatMulShape(A_shape, B_shape), l, name); - } -} - -tvm::Array -ComputeMatMulShape( - const tvm::Array& A_shape, - const tvm::Array& B_shape, - bool trans_a, - bool trans_b) { - auto a_rank = A_shape.size(); - auto b_rank = B_shape.size(); - tvm::Array output_shape; - int64_t output_rank = std::max(a_rank, b_rank); - MTI_ASSERT(a_rank > 0 && b_rank > 0); - if (a_rank == 1 && b_rank == 1) { - MTI_ASSERT(!trans_a && !trans_b); - // reduction, output shape is empty - } else if (a_rank == 1) { - MTI_ASSERT(!trans_a && !trans_b); - output_shape = SliceShapeToDimension(B_shape, b_rank - 2); - output_shape.push_back(B_shape[b_rank - 1]); - } else if (b_rank == 1) { - MTI_ASSERT(!trans_a && !trans_b); - output_shape = SliceShapeToDimension(A_shape, a_rank - 1); - } else { - for (int64_t i = 0; i < output_rank - 2; i++) { - tvm::Expr broadcasted_dim = tvm::make_const(HalideIR::Int(32), 1); - bool broadcasted = - BroadcastDim(A_shape, i, output_rank, broadcasted_dim) && - BroadcastDim(B_shape, i, output_rank, broadcasted_dim); - MTI_ASSERT(broadcasted); - output_shape.push_back(broadcasted_dim); - } - output_shape.push_back(A_shape[a_rank - (trans_a ? 1 : 2)]); - output_shape.push_back(B_shape[b_rank - (trans_b ? 2 : 1)]); - } - return output_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/matmul_ops.h b/onnxruntime/core/codegen/mti/math/matmul_ops.h deleted file mode 100644 index ab9986132d34a..0000000000000 --- a/onnxruntime/core/codegen/mti/math/matmul_ops.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array -ComputeMatMulShape( - const tvm::Array& A_shape, - const tvm::Array& B_shape, - bool trans_a = false, - bool trans_b = false); - -tvm::Tensor MatMul2D(const tvm::Tensor& A, const tvm::Tensor& B, bool trans_a = false, bool trans_b = false, const std::string& name = "matmul2d"); - -tvm::Tensor MatMul(const tvm::Tensor& A, const tvm::Tensor& B, const std::string& name = "matmul"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/reduce_ops.cc b/onnxruntime/core/codegen/mti/math/reduce_ops.cc deleted file mode 100644 index 7d179e2b04316..0000000000000 --- a/onnxruntime/core/codegen/mti/math/reduce_ops.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
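The long comment above, together with ComputeMatMulShape, implements numpy-style matmul shape rules: 1-D operands are promoted and the extra axis dropped afterwards, and leading batch dimensions broadcast against each other. A standalone shape-only sketch of those rules, without the transpose handling; this is illustrative, not the removed implementation.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Output shape of numpy-style MatMul(A, B); throws if batch dims cannot broadcast.
static Shape MatMulShape(Shape a, Shape b) {
  if (a.size() == 1 && b.size() == 1) return {};  // dot product -> scalar
  if (a.size() == 1) {                            // (K) x (..., K, N) -> (..., N)
    Shape out(b.begin(), b.end() - 2);
    out.push_back(b.back());
    return out;
  }
  if (b.size() == 1) return Shape(a.begin(), a.end() - 1);  // (..., M, K) x (K) -> (..., M)
  // General case: broadcast the leading (batch) dims, then append M and N.
  std::size_t rank = std::max(a.size(), b.size());
  Shape out;
  for (std::size_t i = 0; i + 2 < rank; ++i) {
    int64_t da = i + a.size() >= rank ? a[i + a.size() - rank] : 1;  // missing dims act as 1
    int64_t db = i + b.size() >= rank ? b[i + b.size() - rank] : 1;
    if (da != db && da != 1 && db != 1) throw std::runtime_error("incompatible batch dims");
    out.push_back(da == 1 ? db : da);
  }
  out.push_back(a[a.size() - 2]);  // M
  out.push_back(b[b.size() - 1]);  // N
  return out;
}

int main() {
  assert((MatMulShape({2, 3}, {3, 4}) == Shape{2, 4}));
  assert((MatMulShape({5}, {5}) == Shape{}));
  assert((MatMulShape({7, 1, 2, 3}, {4, 3, 5}) == Shape{7, 4, 2, 5}));
  assert((MatMulShape({3}, {2, 3, 4}) == Shape{2, 4}));
  assert((MatMulShape({2, 3, 4}, {4}) == Shape{2, 3}));
  return 0;
}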
- -#include "core/codegen/mti/math/reduce_ops.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor ArgMax(const tvm::Tensor& X, int64_t axis, bool keep_dims, const std::string& name) { - return Rename(topi::argmax(X, ToTvmArrayInt({axis}), keep_dims), name); -} - -tvm::Tensor ArgMin(const tvm::Tensor& X, int64_t axis, bool keep_dims, const std::string& name) { - return Rename(topi::argmin(X, ToTvmArrayInt({axis}), keep_dims), name); -} - -tvm::Tensor ReduceL1(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return ReduceSum(Abs(X), axes, keep_dims, name); -} - -tvm::Tensor ReduceL2(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Sqrt(ReduceSumSquare(X, axes, keep_dims), name); -} - -tvm::Tensor ReduceLogSum(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Log(ReduceSum(X, axes, keep_dims), name); -} - -tvm::Tensor ReduceLogSumExp(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - tvm::Tensor reduce_max = ReduceMax(X, axes, true); - tvm::Tensor exp_delta = Exp(Sub(X, reduce_max)); - tvm::Tensor reduce_max_keep_dims = ReduceMax(X, axes, keep_dims); - return Add(ReduceLogSum(exp_delta, axes, keep_dims), reduce_max_keep_dims, name); -} - -tvm::Tensor ReduceMax(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::max(X, ToTvmArrayInt(axes), keep_dims), name); -} - -tvm::Tensor ReduceMean(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - tvm::Tensor reduce_sum = ReduceSum(X, axes, keep_dims); - tvm::Expr count = tvm::make_const(reduce_sum->dtype, 1.0f); - if (axes.empty()) { - for (const auto& dim : X->shape) - count = count * dim; - } else { - for (int64_t axis : axes) { - int64_t i = HandleNegativeAxis(axis, X->shape.size()); - count = count * X->shape[i]; - } - } - return tvm::compute( - reduce_sum->shape, - [&](const tvm::Array& i) { - return reduce_sum(i) / count; - }, - name); -} - -tvm::Tensor ReduceMin(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::min(X, ToTvmArrayInt(axes), keep_dims), name); -} - -tvm::Tensor ReduceProd(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - auto prod = [](tvm::Expr source, tvm::Array rdom) { - tvm::Var x("x", source.type()), y("y", source.type()); - tvm::Expr Rename_element = tvm::make_const(source.type(), 1.0f); - tvm::ir::CommReducer combiner = - tvm::ir::CommReducerNode::make({x}, {y}, {x * y}, {Rename_element}); - return tvm::ir::Reduce::make(combiner, {source}, rdom, tvm::make_const(tvm::Bool(1), true), 0); - }; - - return Rename(topi::CommReduce(X, ToTvmArrayInt(axes), prod, keep_dims, true), name); -} - -tvm::Tensor ReduceSum(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::sum(X, ToTvmArrayInt(axes), keep_dims), name); -} - -tvm::Tensor ReduceSumSquare(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::sum(Mul(X, X), ToTvmArrayInt(axes), keep_dims), name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/reduce_ops.h 
b/onnxruntime/core/codegen/mti/math/reduce_ops.h deleted file mode 100644 index f782df5e6515f..0000000000000 --- a/onnxruntime/core/codegen/mti/math/reduce_ops.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor ArgMax(const tvm::Tensor& X, - int64_t axis, - bool keep_dims, - const std::string& name = "argmax"); - -tvm::Tensor ArgMin(const tvm::Tensor& X, - int64_t axis, - bool keep_dims, - const std::string& name = "argmin"); - -tvm::Tensor ReduceL1(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_l1"); - -tvm::Tensor ReduceL2(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_l2"); - -tvm::Tensor ReduceLogSum(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_log_sum"); - -tvm::Tensor ReduceLogSumExp(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "argmareduce_log_sum_exp"); - -tvm::Tensor ReduceMax(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_max"); - -tvm::Tensor ReduceMean(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_mean"); - -tvm::Tensor ReduceMin(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_min"); - -tvm::Tensor ReduceProd(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_prod"); - -tvm::Tensor ReduceSum(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_sum"); - -tvm::Tensor ReduceSumSquare(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_sum_square"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/softmax.cc b/onnxruntime/core/codegen/mti/math/softmax.cc deleted file mode 100644 index d7404137bb873..0000000000000 --- a/onnxruntime/core/codegen/mti/math/softmax.cc +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
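ReduceLogSumExp above uses the standard max-shift trick: log(sum(exp(x))) is computed as max(x) + log(sum(exp(x - max(x)))), so large inputs never overflow exp(). A scalar sketch of that identity (illustrative only):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Stable log-sum-exp: shift by the max before exponentiating, then add it back.
static double LogSumExp(const std::vector<double>& x) {
  double m = x.front();
  for (double v : x) m = std::max(m, v);
  double s = 0.0;
  for (double v : x) s += std::exp(v - m);
  return m + std::log(s);
}

int main() {
  // Naive exp() overflows for inputs around 1000; the shifted form does not.
  std::vector<double> x{1000.0, 1000.0};
  assert(std::abs(LogSumExp(x) - (1000.0 + std::log(2.0))) < 1e-9);
  return 0;
}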
- -#include "core/codegen/mti/math/softmax.h" - -#include "core/codegen/mti/tensor/reshape_ops.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Softmax(const tvm::Tensor& input, int64_t axis, const std::string& name) { - tvm::Tensor flatten_t = Flatten(input, axis, "softmax_flatten"); - return Reshape(topi::nn::softmax(flatten_t, 1, name), input->shape, "softmax_reshape"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/softmax.h b/onnxruntime/core/codegen/mti/math/softmax.h deleted file mode 100644 index fb16fbaeb56a2..0000000000000 --- a/onnxruntime/core/codegen/mti/math/softmax.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Softmax(const tvm::Tensor& input, int64_t axis, const std::string& name = "softmax"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/unary_ops.cc b/onnxruntime/core/codegen/mti/math/unary_ops.cc deleted file mode 100644 index ae732ea33e670..0000000000000 --- a/onnxruntime/core/codegen/mti/math/unary_ops.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/unary_ops.h" - -#include "core/codegen/common/settings.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include -#include -#include - -// Using namespace topi for override operator +-*/ -using namespace topi; - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Abs(const tvm::Tensor& X, const std::string& name) { - return abs(X, name); -} - -tvm::Tensor Affine(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return Rename(alphaExpr * X + betaExpr, name); -} - -tvm::Tensor Ceil(const tvm::Tensor& X, const std::string& name) { - return topi::ceil(X, name); -} - -tvm::Tensor Clip(const tvm::Tensor& X, tvm::Expr min_value, tvm::Expr max_value, const std::string& name) { - auto Y = tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::min(tvm::max(X(indices), min_value), max_value); - }, - name); - return Y; -} - -tvm::Tensor Elu(const tvm::Tensor& X, float alpha, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - return Rename(Relu(X) - alphaExpr * Relu(1 - Exp(X)), name); -} - -tvm::Tensor Exp(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::exp(X(indices)); - }, - name); -} - -tvm::Tensor Floor(const tvm::Tensor& X, const std::string& name) { - return topi::floor(X, name); -} - -tvm::Tensor HardSigmoid(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return maximum(0, minimum(1, alphaExpr * X + betaExpr), name); -} - -tvm::Tensor LeakyRelu(const tvm::Tensor& X, float alpha, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - return Rename(Relu(X) - alphaExpr * Relu(0 - X), name); -} - -tvm::Tensor Log(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::log(X(indices)); - }, - name); -} - -tvm::Tensor Neg(const 
tvm::Tensor& X, const std::string& name) { - return negative(X, name); -} - -tvm::Tensor ParametricSoftplus(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return Rename(alphaExpr * Softplus(betaExpr * X), name); -} - -tvm::Tensor Reciprocal(const tvm::Tensor& X, const std::string& name) { - return Rename(1 / X, name); -} - -tvm::Tensor Relu(const tvm::Tensor& X, const std::string& name) { - return maximum(X, 0, name); -} - -tvm::Tensor ScaledTanh(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return Rename(alphaExpr * Tanh(betaExpr * X), name); -} - -tvm::Tensor Selu(const tvm::Tensor& X, float alpha, float gamma, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr gammaExpr = tvm::make_const(X->dtype, gamma); - return Rename(gammaExpr * (-alphaExpr * Relu(1 - Exp(X)) + Relu(X)), name); -} - -tvm::Tensor Sigmoid(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::ir::Select::make(X(indices) > 0, - 1 / (1 + tvm::exp(-X(indices))), - tvm::exp(X(indices)) / (tvm::exp(X(indices)) + 1)); - }, - name); -} - -tvm::Tensor SignNoZero(const tvm::Tensor& X, const std::string& name) { - return Rename(greater_equal(X, 0) * 2 - 1, name); -} - -tvm::Tensor Softplus(const tvm::Tensor& X, const std::string& name) { - return Rename(Log(1 + Exp(Neg(Abs(X)))) + Relu(X), name); -} - -tvm::Tensor Softsign(const tvm::Tensor& X, const std::string& name) { - return Rename(X / (1 + Abs(X)), name); -} - -tvm::Tensor Sqrt(const tvm::Tensor& X, const std::string& name) { - return sqrt(X, name); -} - -tvm::Tensor Tanh(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::ir::Select::make(X(indices) < 0, - (tvm::exp(2 * X(indices)) - 1) / (tvm::exp(2 * X(indices)) + 1), - (1 - tvm::exp(-2 * X(indices))) / (1 + tvm::exp(-2 * X(indices)))); - }, - name); -} - -tvm::Tensor ThresholdedRelu(const tvm::Tensor& X, float alpha, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - return topi::where(greater(X, alphaExpr), X, topi::full_like(X, tvm::make_zero(X->dtype)), name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/unary_ops.h b/onnxruntime/core/codegen/mti/math/unary_ops.h deleted file mode 100644 index aeb336262e547..0000000000000 --- a/onnxruntime/core/codegen/mti/math/unary_ops.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
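The removed Sigmoid above selects between 1/(1+exp(-x)) and exp(x)/(exp(x)+1) based on the sign of x, so exp() is only ever evaluated at a non-positive argument and cannot overflow; Tanh uses the analogous guard. A scalar version of the same trick (illustrative):

#include <cassert>
#include <cmath>

// Evaluate 1 / (1 + exp(-x)) without ever exponentiating a large positive value.
static double StableSigmoid(double x) {
  if (x > 0.0) return 1.0 / (1.0 + std::exp(-x));  // exp(-x) <= 1
  double e = std::exp(x);                          // exp(x) <= 1
  return e / (e + 1.0);
}

int main() {
  assert(std::abs(StableSigmoid(0.0) - 0.5) < 1e-12);
  assert(StableSigmoid(1000.0) == 1.0);   // the unguarded exp(x)/(exp(x)+1) form gives inf/inf = NaN here
  assert(StableSigmoid(-1000.0) == 0.0);  // exp underflows cleanly to 0
  return 0;
}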
- -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Abs(const tvm::Tensor& X, const std::string& name = "abs"); -tvm::Tensor Affine(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "affine"); -tvm::Tensor Ceil(const tvm::Tensor& X, const std::string& name = "ceil"); -tvm::Tensor Clip(const tvm::Tensor& X, tvm::Expr min_value, tvm::Expr max_value, const std::string& name = "clip"); -tvm::Tensor Elu(const tvm::Tensor& X, float alpha, const std::string& name = "elu"); -tvm::Tensor Exp(const tvm::Tensor& X, const std::string& name = "exp"); -tvm::Tensor Floor(const tvm::Tensor& X, const std::string& name = "floor"); -tvm::Tensor HardSigmoid(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "hard_sigmoid"); -tvm::Tensor LeakyRelu(const tvm::Tensor& X, float alpha, const std::string& name = "leaky_relu"); -tvm::Tensor Log(const tvm::Tensor& X, const std::string& name = "log"); -tvm::Tensor Neg(const tvm::Tensor& X, const std::string& name = "neg"); -tvm::Tensor ParametricSoftplus(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "parametric_softplus"); -tvm::Tensor Reciprocal(const tvm::Tensor& X, const std::string& name = "reciprocal"); -tvm::Tensor Relu(const tvm::Tensor& X, const std::string& name = "relu"); -tvm::Tensor ScaledTanh(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "scaled_tanh"); -tvm::Tensor Selu(const tvm::Tensor& X, float alpha, float gamma, const std::string& name = "selu"); -tvm::Tensor Sigmoid(const tvm::Tensor& X, const std::string& name = "sigmoid"); -tvm::Tensor SignNoZero(const tvm::Tensor& X, const std::string& name = "sign_no_zero"); -tvm::Tensor Softplus(const tvm::Tensor& X, const std::string& name = "softplus"); -tvm::Tensor Softsign(const tvm::Tensor& X, const std::string& name = "softsign"); -tvm::Tensor Sqrt(const tvm::Tensor& X, const std::string& name = "sqrt"); -tvm::Tensor Tanh(const tvm::Tensor& X, const std::string& name = "tanh"); -tvm::Tensor ThresholdedRelu(const tvm::Tensor& X, float alpha, const std::string& name = "thresholded_relu"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc b/onnxruntime/core/codegen/mti/mti_tvm_utils.cc deleted file mode 100644 index 8e73629c05614..0000000000000 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/mti/mti_tvm_utils.h" - -#include "core/codegen/common/settings.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array ToTvmArray(gsl::span shape) { - tvm::Array arr; - for (size_t i = 0; i < shape.size(); ++i) { - arr.push_back(tvm::Expr(static_cast(shape[i]))); - } - return arr; -} - -tvm::Array ToTvmArrayInt(gsl::span shape) { - tvm::Array arr; - for (size_t i = 0; i < shape.size(); ++i) { - arr.push_back(shape[i]); - } - return arr; -} - -tvm::Expr SizeToDimension(const tvm::Array& shape, int64_t axis) { - tvm::Expr size(1); - auto rank = shape.size(); - if (static_cast(axis) != rank) { - axis = HandleNegativeAxis(axis, rank); - } - for (size_t d = 0; d < std::min(rank, static_cast(axis)); ++d) - size = tvm::ir::Simplify(size * shape[d]); - return size; -} - -tvm::Expr SizeFromDimension(const tvm::Array& shape, int64_t axis) { - tvm::Expr size(1); - auto rank = shape.size(); - if (static_cast(axis) != rank) { - axis = HandleNegativeAxis(axis, rank); - } - for (size_t d = static_cast(axis); d < rank; ++d) - size = tvm::ir::Simplify(size * shape[d]); - return size; -} - -tvm::Expr RoundUp(tvm::Expr value, tvm::Expr alignment) { - return tvm::ir::Simplify((value + alignment - 1) / alignment * alignment); -} - -tvm::Array ConcatShapes( - const tvm::Array& shape1, - const tvm::Array& shape2) { - tvm::Array result; - for (size_t i = 0; i < shape1.size(); i++) - result.push_back(shape1[i]); - for (size_t i = 0; i < shape2.size(); i++) - result.push_back(shape2[i]); - return result; -} - -tvm::Tensor Rename(tvm::Tensor X, const std::string& name) { - const_cast(X->op->name) = name; - return X; -} - -tvm::Array SliceShape(const tvm::Array& shape, const std::vector& axes) { - tvm::Array new_shape; - for (auto axis : axes) { - CHECK(axis < static_cast(shape.size())); - new_shape.push_back(shape[axis]); - } - return new_shape; -} - -tvm::Array SliceShapeFromDimension(const tvm::Array& shape, int64_t axis) { - int64_t rank = static_cast(shape.size()); - axis = HandleNegativeAxis(axis, rank); - std::vector axes; - for (auto i = axis; i < rank; ++i) - axes.push_back(i); - return SliceShape(shape, axes); -} - -tvm::Array SliceShapeToDimension(const tvm::Array& shape, int64_t axis) { - int64_t rank = static_cast(shape.size()); - axis = HandleNegativeAxis(axis, rank); - std::vector axes; - for (auto i = 0; i < axis; ++i) - axes.push_back(i); - return SliceShape(shape, axes); -} - -bool IsOne(const tvm::Array& shape, int64_t axis) { - int64_t rank = static_cast(shape.size()); - axis = HandleNegativeAxis(axis, rank); - const auto& dim = shape[axis]; - auto* p = tvm::as_const_int(dim); - return p != nullptr && *p == 1; -} - -tvm::Tensor Promote(const tvm::Expr& expr, const tvm::Array& shape, const std::string& name) { - return tvm::compute( - shape, - [&](const tvm::Array&) { - return expr; - }, - name); -} - -void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module) { - const codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); - if (!settings.HasOption(codegen::CodeGenSettings::kCodeGenDumpModule)) - return; - - // ISSUE: note that all option values are converted to lower case. It doesn't cause - // any issue currently, because all supported formats (i.e. file exts) are of lower case. - // Just keep in mind that we might have issue if somehow we started to support dump - // formats with upper case, although it's quite unlikely. 
- std::string format = settings.GetOptionValue(codegen::CodeGenSettings::kCodeGenDumpModule); - std::string module_filename = filename + "." + format; - module->SaveToFile(module_filename, format); -} - -tvm::Tensor MakeZeroTensor(const tvm::Array& shape, - HalideIR::Type type, - const std::string& name) { - auto l = [&](const tvm::Array& /*indices*/) { - return tvm::make_zero(type); - }; - return tvm::compute(shape, l, name); -} - -bool BroadcastDim(const tvm::Array& shape, size_t i, size_t output_rank, tvm::Expr& dim) { - if (i >= output_rank - shape.size()) { - auto new_dim = shape[shape.size() - output_rank + i]; - if (tvm::ir::Equal(new_dim, dim)) - return true; - - const int64_t* p_new = tvm::as_const_int(new_dim); - if (p_new != nullptr && *p_new == 1) { - return true; - } else { - const int64_t* p_old = tvm::as_const_int(dim); - if (p_old != nullptr && *p_old == 1) { - dim = new_dim; - return true; - } - } - return false; - } - // auto broadcast to outer dims - return true; -} - -tvm::Array MakeInputsForExtern(const tvm::Array& inputs, const std::string& name) { - // note that currently TVM StorageFlatten creates strides like max(symbolic_dim, 1) - // which is not zero when checking symbolic_dim - max(symbolic_dim, 1) - // then triggers error like: Trying to bind compact buffer to strided one - // here's a workaround to reshape inputs to avoid that - tvm::Array fixed_inputs; - for (size_t idx_input = 0; idx_input < inputs.size(); ++idx_input) { - const auto& input = inputs[idx_input]; - tvm::Array fixed_shape; - if (input->shape.size() > 0) { - // stride compute does not use dim 0, so directly push to fixed_shape - fixed_shape.push_back(input->shape[0]); - bool need_fix = false; - for (size_t idx_dim = 1; idx_dim < input->shape.size(); ++idx_dim) { - const auto& dim = input->shape[idx_dim]; - if (tvm::as_const_int(dim) == nullptr) { - fixed_shape.push_back(tvm::max(dim, tvm::make_const(HalideIR::Int(32), 1))); - need_fix = true; - } else { - fixed_shape.push_back(dim); - } - } - if (need_fix) { - fixed_inputs.push_back(tvm_codegen::Reshape(input, fixed_shape, name + "_" + std::to_string(idx_input))); - continue; - } - } - // no fix needed - fixed_inputs.push_back(input); - } - return fixed_inputs; -} - -// Make sure idx is clamped in the range of [-bound, bound - 1] -tvm::Expr ClampIndex(const tvm::Expr& idx, const tvm::Expr& bound) { - // when idx >= 0, we take tvm::max(..., 0), because (idx < 0) is 0 - // when idx < 0, we take bound + tvm::max(...), because tvm::max(idx, 0) is 0 - return tvm::max(tvm::min(idx, bound - 1), 0) + - (idx < 0) * (bound + tvm::max(idx, -bound)); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.h b/onnxruntime/core/codegen/mti/mti_tvm_utils.h deleted file mode 100644 index c2a14106c1686..0000000000000 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
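ClampIndex above maps an index that may be negative (counting from the end, as in ONNX and Python slicing) into [0, bound - 1]: non-negative indices are clamped to bound - 1, indices in [-bound, -1] become bound + idx, and anything below -bound saturates at 0. The same arithmetic restated in scalar form so the cases can be checked directly (illustrative):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Scalar restatement of the removed expression:
//   max(min(idx, bound - 1), 0) + (idx < 0) * (bound + max(idx, -bound))
static int64_t ClampIndexScalar(int64_t idx, int64_t bound) {
  return std::max<int64_t>(std::min(idx, bound - 1), 0) +
         (idx < 0 ? 1 : 0) * (bound + std::max(idx, -bound));
}

int main() {
  const int64_t bound = 4;
  assert(ClampIndexScalar(2, bound) == 2);   // in range: unchanged
  assert(ClampIndexScalar(9, bound) == 3);   // too large: clamped to bound - 1
  assert(ClampIndexScalar(-1, bound) == 3);  // negative: counts from the end
  assert(ClampIndexScalar(-4, bound) == 0);  // exactly -bound: first element
  assert(ClampIndexScalar(-9, bound) == 0);  // below -bound: saturates at 0
  return 0;
}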
- -#pragma once - -#include -#include -#include -#include -#include "core/codegen/mti/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array ToTvmArray(gsl::span shape); - -tvm::Array ToTvmArrayInt(gsl::span shape); - -// Helper function to compute sub shape size to axis (not included) -tvm::Expr SizeToDimension(const tvm::Array& shape, int64_t axis); - -// Helper function to compute sub shape size from axis (included) -tvm::Expr SizeFromDimension(const tvm::Array& shape, int64_t axis); - -// Helper function to align -tvm::Expr RoundUp(tvm::Expr value, tvm::Expr alignment); - -tvm::Array ConcatShapes( - const tvm::Array& shape1, - const tvm::Array& shape2); - -// Helper function to rename tvm::Tensor -tvm::Tensor Rename(tvm::Tensor X, const std::string& name); - -// Helper function to slice TVM shape -tvm::Array SliceShape(const tvm::Array& shape, const std::vector& axes); - -// Helper function to slice TVM shape from axis (inclusive). -// Basically, this function returns the shape of [axis, shape.size()-1] -tvm::Array SliceShapeFromDimension(const tvm::Array& shape, int64_t axis); - -// this function returns the shape of [0, axis-1] -tvm::Array SliceShapeToDimension(const tvm::Array& shape, int64_t axis); - -// check if dimension is 1 -bool IsOne(const tvm::Array& shape, int64_t axis); - -// Helper function to convert tvm::Expr to tvm::Tensor -tvm::Tensor Promote(const tvm::Expr& expr, - const tvm::Array& shape, - const std::string& name = "PromoteExpr"); - -tvm::Tensor MakeZeroTensor(const tvm::Array& shape, HalideIR::Type type, const std::string& name); - -void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module); - -bool BroadcastDim(const tvm::Array& shape, size_t i, size_t output_rank, tvm::Expr& dim); - -inline int64_t HandleNegativeAxis(int64_t axis, int64_t rank) { - MTI_ASSERT(axis >= -rank && axis <= rank - 1); - return axis = axis < 0 ? (axis + rank) : axis; -} - -// Make sure idx is clamped in the range of [-bound, bound - 1] -tvm::Expr ClampIndex(const tvm::Expr& idx, const tvm::Expr& bound); - -// Helper function to workaround tvm ExternOp issue when input has symbolic dimensions -tvm::Array MakeInputsForExtern(const tvm::Array& inputs, const std::string& name = "make_inputs_for_extern"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/conv_ops.cc b/onnxruntime/core/codegen/mti/nn/conv_ops.cc deleted file mode 100644 index e2d4acc8843ad..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/conv_ops.cc +++ /dev/null @@ -1,193 +0,0 @@ -#include "core/codegen/mti/nn/conv_ops.h" - -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/tensor/pad_ops.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/codegen/mti/tensor/transpose.h" - -namespace onnxruntime { -namespace tvm_codegen { - -static tvm::Tensor PadTensor1D(const tvm::Tensor& input, - const tvm::Array& padding, - size_t width_axis, - const std::string& name) { - auto pad_left = padding[0]; - auto pad_right = padding[1]; - - tvm::Array pad_before(std::vector(input->shape.size(), 0)); - pad_before.Set(width_axis, pad_left); - tvm::Array pad_after(std::vector(input->shape.size(), 0)); - pad_after.Set(width_axis, pad_right); - - const int64_t* padding_w0 = tvm::as_const_int(pad_left); - const int64_t* padding_w1 = tvm::as_const_int(pad_right); - - const bool do_pad = ((padding_w0 != nullptr && *padding_w0) || - (padding_w1 != nullptr && *padding_w1)); - - return do_pad ? 
Pad(input, pad_before, pad_after, - 0, "constant", name + "_input_padded") - : input; -} - -tvm::Tensor Conv1D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& out_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - size_t channel_axis = 1; - size_t width_axis = 2; - - auto stride_width = stride[width_axis - 2]; - - auto input_padded = PadTensor1D(input, padding, width_axis, name); - auto rc = tvm::reduce_axis((tvm::Range(0, filter->shape[1])), "rc"); - auto rx = tvm::reduce_axis((tvm::Range(0, filter->shape[2])), "rx"); - - return tvm::compute( - out_shape, - [&](const tvm::Array& output) { - tvm::Array indices; - for (const tvm::Var& var : output) { - indices.push_back(var); - } - indices.Set(channel_axis, rc); - indices.Set(width_axis, output[width_axis] * stride_width + rx); - - return tvm::sum(input_padded(indices) * filter({output[1], rc, rx}), - {rc, rx}); - }, - name); -} - -tvm::Tensor Conv2D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - return Conv2D_native(input, filter, output_shape, stride, padding); -} - -static tvm::Tensor PadTensor2D(const tvm::Tensor& input, - const tvm::Array& padding, - size_t height_axis, - size_t width_axis, - const std::string& name) { - auto pad_top = padding[0]; - auto pad_left = padding[1]; - auto pad_bottom = padding[2]; - auto pad_right = padding[3]; - - tvm::Array pad_before(std::vector(input->shape.size(), 0)); - pad_before.Set(height_axis, pad_top); - pad_before.Set(width_axis, pad_left); - - tvm::Array pad_after(std::vector(input->shape.size(), 0)); - pad_after.Set(height_axis, pad_bottom); - pad_after.Set(width_axis, pad_right); - - const int64_t* padding_h0 = tvm::as_const_int(pad_top); - const int64_t* padding_w0 = tvm::as_const_int(pad_left); - const int64_t* padding_h1 = tvm::as_const_int(pad_bottom); - const int64_t* padding_w1 = tvm::as_const_int(pad_right); - - const bool do_pad = ((padding_h0 != nullptr && *padding_h0) || - (padding_w0 != nullptr && *padding_w0)) || - ((padding_h1 != nullptr && *padding_h1) || - (padding_w1 != nullptr && *padding_w1)); - - return do_pad ? 
Pad(input, pad_before, pad_after, - 0, "constant", name + "_input_padded") - : input; -} - -tvm::Tensor Conv2D_native(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& out_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - size_t channel_axis = 1; - size_t height_axis = 2; - size_t width_axis = 3; - - auto stride_height = stride[height_axis - 2]; - auto stride_width = stride[width_axis - 2]; - - auto input_padded = PadTensor2D(input, padding, height_axis, width_axis, name); - - auto rc = tvm::reduce_axis((tvm::Range(0, filter->shape[1])), "rc"); - auto ry = tvm::reduce_axis((tvm::Range(0, filter->shape[2])), "ry"); - auto rx = tvm::reduce_axis((tvm::Range(0, filter->shape[3])), "rx"); - - return tvm::compute( - out_shape, - [&](const tvm::Array& output) { - tvm::Array indices; - for (const tvm::Var& var : output) { - indices.push_back(var); - } - indices.Set(channel_axis, rc); - indices.Set(height_axis, output[height_axis] * stride_height + ry); - indices.Set(width_axis, output[width_axis] * stride_width + rx); - - return tvm::sum(input_padded(indices) * filter({output[1], rc, ry, rx}), - {rc, ry, rx}); - }, - name); -} - -tvm::Tensor Conv2D_gemm(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& out_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - size_t height_axis = 2; - size_t width_axis = 3; - - auto stride_height = stride[height_axis - 2]; - auto stride_width = stride[width_axis - 2]; - - auto input_padded = PadTensor2D(input, padding, height_axis, width_axis, name); - - tvm::Array img_col_tmp(std::vector(6, 0)); - img_col_tmp.Set(0, out_shape[0]); - img_col_tmp.Set(1, out_shape[2]); - img_col_tmp.Set(2, out_shape[3]); - img_col_tmp.Set(3, filter->shape[1]); - img_col_tmp.Set(4, filter->shape[2]); - img_col_tmp.Set(5, filter->shape[3]); - - auto img_col = tvm::compute( - img_col_tmp, - [&](const tvm::Array& output) { - tvm::Array indices; - indices.push_back(output[0]); - indices.push_back(output[3]); - indices.push_back(output[1] * stride_height + output[4]); - indices.push_back(output[2] * stride_width + output[5]); - return input_padded(indices); - }, - name); - - tvm::Array input_col_shape(std::vector(2, 0)); - input_col_shape.Set(0, img_col_tmp[1] * img_col_tmp[2]); - input_col_shape.Set(1, img_col_tmp[3] * img_col_tmp[4] * img_col_tmp[5]); - auto input_col = Reshape(img_col, input_col_shape); - - tvm::Array filter_row_shape(std::vector(2, 0)); - filter_row_shape.Set(0, filter->shape[0]); - filter_row_shape.Set(1, filter->shape[1] * filter->shape[2] * filter->shape[3]); - auto filter_row = Reshape(filter, filter_row_shape, name); - - auto Y = MatMul2D(input_col, filter_row, false, true, name); - auto Y_T = Transpose(Y, /*axes=*/{}, name); - return Reshape(Y_T, out_shape, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/conv_ops.h b/onnxruntime/core/codegen/mti/nn/conv_ops.h deleted file mode 100644 index 1396c216865a7..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/conv_ops.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Conv1D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv1d"); - -tvm::Tensor Conv2D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv2d"); - -tvm::Tensor Conv2D_native(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv2d_native"); - -tvm::Tensor Conv2D_gemm(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv2d_gemm"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/lstm.cc b/onnxruntime/core/codegen/mti/nn/lstm.cc deleted file mode 100644 index 1148b0924e869..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/lstm.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/nn/lstm.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/math/reduce_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/codegen/mti/tensor/split.h" - -namespace onnxruntime { -namespace tvm_codegen { - -/* -`X` - input tensor -`i` - input gate -`o` - output gate -`f` - forget gate -`c` - cell gate -`t` - time step (t-1 means previous time step) - -`W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates -`R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates -`Wb[iofc]` - W bias vectors for input, output, forget, and cell gates -`Rb[iofc]` - R bias vectors for input, output, forget, and cell gates -`P[iof]` - P peephole weight vector for input, output, and forget gates -`WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates -`RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates -`WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates -`RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates -`PB[iof]` - P peephole weight vector for backward input, output, and forget gates - -`H` - Hidden state -`num_directions` - 2 if direction == bidirectional else 1 - -Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - Ct = ft (.) Ct-1 + it (.) ct - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - Ht = ot (.) 
h(Ct) -*/ - -void LSTM_cell( - const LSTMAttributes& lstm_attrs, - const tvm::Tensor& X, - const tvm::Tensor& W, - const tvm::Tensor& R, - const tvm::Tensor& B, - bool has_B, - const tvm::Tensor& prev_H, - const tvm::Tensor& prev_C, - const tvm::Tensor& P, - bool has_P, - tvm::Tensor& Y_h, - tvm::Tensor& Y_c) { - // Input projection: Xt*(W[iofc]^T) for forward direction or Xt*(WB[iofc]^T) for reverse direction - // (batch_size, input_size) * trans(4 * hidden_size, input_size) => (batch_size, 4 * hidden_size) - tvm::Tensor input_proj = MatMul2D(X, W, /*trans_a*/ false, /*trans_b*/ true); - - // Hidden projection: Ht-1*(R[iofc]^T) for forward direction or Ht-1*(RB[iofc]^T) for reverse direction - // (batch_size, hidden_size) * trans(4 * hidden_size, hidden_size) => (batch_size, 4 * hidden_size) - tvm::Tensor hidden_proj = MatMul2D(prev_H, R, /*trans_a*/ false, /*trans_b*/ true); - - // (batch_size, 4 * hidden_size) - tvm::Tensor sum_proj = Add(input_proj, hidden_proj); - - // Concatenation of [Wb[iofc], Rb[iofc]] or [WBb[iofc], RBb[iofc]] - if (has_B) { - // (8 * hidden_size) -> (2, 4 * hidden_size) -> (1, 4 * hidden_size), should be done in const folding - tvm::Tensor reduce_B = - ReduceSum(Reshape(B, {2, 4 * static_cast(lstm_attrs.hidden_size)}), {0}, /*keep_dims*/ true); - // (batch_size, 4 * hidden_size) via broadcasting reduce_B - sum_proj = Add(sum_proj, reduce_B); - } - - std::vector iofc_sum_split_sizes(4, lstm_attrs.hidden_size); - // Split sum_proj into iofc, where each gate proj is of (batch_size, hidden_size) - tvm::Array iofc_sum_projs = Split(sum_proj, ToTvmArray(iofc_sum_split_sizes), /*axis*/ 1); - MTI_ASSERT(iofc_sum_projs.size() == 4); - tvm::Tensor i_proj = iofc_sum_projs[0], - o_proj = iofc_sum_projs[1], - f_proj = iofc_sum_projs[2], - c_proj = iofc_sum_projs[3]; - - tvm::Tensor P_i, P_o, P_f; - if (has_P) { - std::vector iof_p_split_sizes(3, lstm_attrs.hidden_size); - // Split P into P_i, P_o, P_f, in const pre-processing (P_i, P_f might be merged?) - // where each P_[iof] has the shape of (hidden_size) - tvm::Array iof_P_projs = Split(P, ToTvmArray(iof_p_split_sizes), /*axis*/ 0); - MTI_ASSERT(iof_P_projs.size() == 3); - P_i = iof_P_projs[0], - P_o = iof_P_projs[1], - P_f = iof_P_projs[2]; - - // (batch_size, hidden_size) via broadcasting P_[if] - i_proj = Add(i_proj, Mul(P_i, prev_C)); - f_proj = Add(f_proj, Mul(P_f, prev_C)); - } - - // TODO: handle more general cases for activations f, h, g and activation_alpha and - // activation_beta. We may consider to move some code such as ActivationInfo from deep_cpu_lstm - // into a common header file, because the code can be used here. - - // Note that by default f = Sigmoid, g = Tanh, h = Tanh - - // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - // shape: (batch_size, hidden_size) - tvm::Tensor i_t = Sigmoid(i_proj); - // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - // shape: (batch_size, hidden_size) - tvm::Tensor f_t = Sigmoid(f_proj); - // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - // shape: (batch_size, hidden_size) - tvm::Tensor c_t = Tanh(c_proj); - - // Ct = ft (.) Ct-1 + it (.) ct - // shape: (batch_size, hidden_size) - Y_c = Add(Mul(f_t, prev_C), Mul(i_t, c_t), Y_c->op->name); - - // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - // shape: (batch_size, hidden_size) - if (has_P) { - o_proj = Add(o_proj, Mul(P_o, Y_c)); - } - // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - // shape: (batch_size, hidden_size) - o_proj = Sigmoid(o_proj); - // Ht = ot (.) 
h(Ct) - // shape: (batch_size, hidden_size) - Y_h = Mul(o_proj, Tanh(Y_c), Y_h->op->name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/lstm.h b/onnxruntime/core/codegen/mti/nn/lstm.h deleted file mode 100644 index 851fa880c4427..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/lstm.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -// A bubble now. But don't remove it -// TODO: refactor the LSTMcell building to a tvm function -// and move it here - -namespace onnxruntime { -namespace tvm_codegen { - -struct LSTMAttributes { - LSTMAttributes(int64_t hidden_size_p) : hidden_size(hidden_size_p) {} - int64_t hidden_size; -}; - -void LSTM_cell( - const LSTMAttributes& lstm_attrs, - const tvm::Tensor& X, - const tvm::Tensor& W, - const tvm::Tensor& R, - const tvm::Tensor& B, - bool has_B, - const tvm::Tensor& prev_H, - const tvm::Tensor& prev_C, - const tvm::Tensor& P, - bool has_P, - tvm::Tensor& Y_h, - tvm::Tensor& Y_c); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/pool_ops.cc b/onnxruntime/core/codegen/mti/nn/pool_ops.cc deleted file mode 100644 index 868a14748cabc..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/pool_ops.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/nn/pool_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/mlas/inc/mlas.h" -#include "core/providers/cpu/nn/pool_attributes.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// TODO: topi only support 2d-pool, MaxPool1d and MaxPool3d will need to be added if necessary. -// only support version < 8 for topi doesn't come with implementation to output index tensor -tvm::Tensor MaxPool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::pool(input, - ToTvmArray(pool_attrs.kernel_shape), - ToTvmArray(pool_attrs.strides), - ToTvmArray(pool_attrs.pads), - /*pool_type*/ topi::nn::kMaxPool, - /*ceil_mode*/ false, - /*layout*/ pool_attrs.storage_order == 0 ? 
"NCWH" : "NCHW", - pool_attrs.count_include_pad); -} - -tvm::Tensor AveragePool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::pool(input, - ToTvmArray(pool_attrs.kernel_shape), - ToTvmArray(pool_attrs.strides), - ToTvmArray(pool_attrs.pads), - /*pool_type*/ topi::nn::kAvgPool, - /*ceil_mode*/ false, - /*layout*/ "NCHW", - pool_attrs.count_include_pad); -} - -tvm::Tensor GlobalMaxPool(const tvm::Tensor& input, - const PoolAttributes& /*pool_attrs*/, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::global_pool(input, - /*pool_type*/ topi::nn::kMaxPool, - /*layout*/ "NCHW"); -} - -tvm::Tensor GlobalAveragePool(const tvm::Tensor& input, - const PoolAttributes& /*pool_attrs*/, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::global_pool(input, - /*pool_type*/ topi::nn::kAvgPool, - /*layout*/ "NCHW"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/pool_ops.h b/onnxruntime/core/codegen/mti/nn/pool_ops.h deleted file mode 100644 index d381f9ddff859..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/pool_ops.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { - -// Forward declaration -struct PoolAttributes; - -namespace tvm_codegen { - -tvm::Tensor MaxPool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "max_pool"); - -tvm::Tensor AveragePool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "average_pool"); - -tvm::Tensor GlobalMaxPool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "global_max_pool"); - -tvm::Tensor GlobalAveragePool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "global_average_pool"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/cast_ops.cc b/onnxruntime/core/codegen/mti/tensor/cast_ops.cc deleted file mode 100644 index a8fc86488d82b..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/cast_ops.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/cast_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Cast(const tvm::Tensor& X, tvm::Type type, const std::string& name) { - return topi::cast(X, type, name); -} - -// handle cases where bool is reprented as uint8 (e.g. in ONNX). -tvm::Tensor CastToUInt8Bool(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - auto val = X(indices); - // A special cast from float16 to bool, first cast up to float32, - // to workaround a float16 bug in many TVM backends. - // Intel Skylake is one of them. 
https://github.com/dmlc/tvm/issues/2959 - // TODO: remove it, after TVM is fixed - if (X->dtype == HalideIR::Float(16)) - val = tvm::cast(HalideIR::Float(32), val); - return tvm::ir::Select::make(topi::equal(val, tvm::make_zero(val.type())), - tvm::make_zero(HalideIR::UInt(8)), - tvm::make_const(HalideIR::UInt(8), 1)); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/cast_ops.h b/onnxruntime/core/codegen/mti/tensor/cast_ops.h deleted file mode 100644 index 02f6f9cb1fde7..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/cast_ops.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Cast(const tvm::Tensor& X, tvm::Type type, const std::string& name = "cast"); -tvm::Tensor CastToUInt8Bool(const tvm::Tensor& X, const std::string& name = "cast_uint8_bool"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc b/onnxruntime/core/codegen/mti/tensor/concat_ops.cc deleted file mode 100644 index 3394d5b7e00a2..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/concat_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Concat(const tvm::Array& inputs, - int64_t axis, - const std::string& name) { - return ConcatSafe(inputs, axis, name); -} - -// Note topi's implementation requires control flow within iterations to avoid out-of-bound access. -// Therefore, MTI implements a ConcatSafe that does not have out-of-bound access, -// and does not requires control or predicate. -tvm::Tensor ConcatSafe(const tvm::Array& inputs, - int64_t axis, - const std::string& name) { - axis = HandleNegativeAxis(axis, gsl::narrow(inputs[0]->shape.size())); - MTI_ASSERT(axis < gsl::narrow(inputs[0]->shape.size()) && "axis out of bounds"); - - tvm::Array axis_sizes; - for (auto t : inputs) { - axis_sizes.push_back(t->shape[axis]); - } - - tvm::Expr join_size = axis_sizes[0]; - for (size_t i = 1; i < axis_sizes.size(); ++i) { - join_size += axis_sizes[i]; - } - join_size = tvm::ir::Simplify(join_size); - tvm::Array out_shape; - for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { - out_shape.push_back(i == gsl::narrow(axis) ? 
join_size : inputs[0]->shape[i]); - } - - return tvm::compute( - out_shape, [&](const tvm::Array& ovars) { - tvm::Array indices; - - // preset - tvm::Expr min = 0; - tvm::Expr extent = axis_sizes[0]; - tvm::Expr offset = 0; - tvm::Expr ret; - - // input i = 0 - for (size_t j = 0; j < ovars.size(); ++j) { - if (j == gsl::narrow(axis)) { - tvm::Expr ivar = ovars[j]; - indices.push_back(tvm::max(tvm::min(ivar, min + extent - 1), min)); - } else { - indices.push_back(ovars[j]); - } - } - ret = inputs[0](indices); - - for (size_t i = 1; i < inputs.size(); ++i) { - offset += extent; - tvm::Expr min = 0; - extent = axis_sizes[i]; - auto j = gsl::narrow(axis); - tvm::Expr ivar = ovars[j] - offset; - indices.Set(j, tvm::max(tvm::min(ivar, min + extent - 1), min)); - - ret = tvm::ir::Select::make(ivar >= 0, - inputs[i](indices), - ret); - } - - return ret; - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/concat_ops.h b/onnxruntime/core/codegen/mti/tensor/concat_ops.h deleted file mode 100644 index 153afebb44615..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/concat_ops.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Concat(const tvm::Array& inputs, int64_t axis, const std::string& name = "concat"); -tvm::Tensor ConcatSafe(const tvm::Array& inputs, int64_t axis, const std::string& name = "concat_safe"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/crop.cc b/onnxruntime/core/codegen/mti/tensor/crop.cc deleted file mode 100644 index 3fe569100df12..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/crop.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/mti/tensor/crop.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Crop(const tvm::Tensor& t, - const tvm::Array& border, - const tvm::Array& scale, - const std::string& name) { - MTI_ASSERT(t->shape.size() == 4); - tvm::Expr N = t->shape[0]; - tvm::Expr C = t->shape[1]; - tvm::Expr H = t->shape[2]; - tvm::Expr W = t->shape[3]; - - MTI_ASSERT(border.size() == 4); - tvm::Expr leftBorder = border[0]; - tvm::Expr topBorder = border[1]; - tvm::Expr rightBorder = border[2]; - tvm::Expr bottomBorder = border[3]; - - tvm::Expr bottomLimit = H - bottomBorder; - tvm::Expr rightLimit = W - rightBorder; - - if (!scale.empty()) { - CHECK_EQ(scale.size(), 2); - bottomLimit = topBorder + scale[0]; - rightLimit = leftBorder + scale[1]; - } - - tvm::Array output_shape; - output_shape.push_back(tvm::ir::Simplify(N)); - output_shape.push_back(tvm::ir::Simplify(C)); - output_shape.push_back(tvm::ir::Simplify(bottomLimit - topBorder)); - output_shape.push_back(tvm::ir::Simplify(rightLimit - leftBorder)); - - auto l = [&](const tvm::Array& ovars) { - tvm::Array indices; - - indices.push_back(tvm::min(ovars[0], output_shape[0] - 1)); - indices.push_back(tvm::min(ovars[1], output_shape[1] - 1)); - indices.push_back(tvm::min(topBorder + ovars[2], topBorder + output_shape[2] - 1)); - indices.push_back(tvm::min(leftBorder + ovars[3], leftBorder + output_shape[3] - 1)); - - return t(indices); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/crop.h b/onnxruntime/core/codegen/mti/tensor/crop.h deleted file mode 100644 index ffb6a05c70504..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/crop.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Crop(const tvm::Tensor& t, - const tvm::Array& border, - const tvm::Array& scale = {}, - const std::string& name = "crop"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/expand.cc b/onnxruntime/core/codegen/mti/tensor/expand.cc deleted file mode 100644 index cdac4f56e1f9f..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/expand.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/mti/tensor/expand.h" -#include "core/codegen/mti/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Expand(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name) { - MTI_ASSERT(new_shape.size() >= X->shape.size()); - return tvm::compute( - new_shape, - [&](const tvm::Array& out_indices) { - tvm::Array indices; - size_t broadcasted_rank = new_shape.size() - X->shape.size(); - for (size_t d = broadcasted_rank; d < new_shape.size(); ++d) { - if (tvm::is_const_int(X->shape[d - broadcasted_rank], 1)) { - indices.push_back(tvm::make_zero(HalideIR::Int(32))); - } else { - indices.push_back(out_indices[d]); - } - } - return X(indices); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/expand.h b/onnxruntime/core/codegen/mti/tensor/expand.h deleted file mode 100644 index d66d41aeb0194..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/expand.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Expand(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name = "expand"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather.cc b/onnxruntime/core/codegen/mti/tensor/gather.cc deleted file mode 100644 index 152b3981f1623..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/gather.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gather(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name) { - // handle negative axis - axis = HandleNegativeAxis(axis, gsl::narrow(t->shape.size())); - size_t axis_t = gsl::narrow(axis); - - tvm::Array output_shape; - for (size_t i = 0; i < axis_t; ++i) - output_shape.push_back(t->shape[i]); - - for (size_t i = 0; i < indices->shape.size(); ++i) - output_shape.push_back(indices->shape[i]); - - for (size_t i = axis_t + 1; i < t->shape.size(); ++i) - output_shape.push_back(t->shape[i]); - - tvm::Expr idx_upper_bound = t->shape[axis_t]; - auto l = [&](const tvm::Array& ovars) { - tvm::Array ivars; - for (size_t i = 0; i < t->shape.size(); ++i) { - if (i < axis_t) { - ivars.push_back(ovars[i]); - } else if (i == axis_t) { - tvm::Array idx_vars; - for (size_t d = 0; d < indices->shape.size(); ++d) - idx_vars.push_back(ovars[axis_t + d]); - // make sure idx is clamped in the range of [-idx_upper_bound, idx_upper_bound - 1] - tvm::Expr real_idx = tvm_codegen::ClampIndex(indices(idx_vars), idx_upper_bound); - ivars.push_back(tvm::cast(tvm::Int(32), real_idx)); // tvm indices must be Int32 - } else { - ivars.push_back(ovars[i - 1 + indices->shape.size()]); - } - } - return t(ivars); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather.h b/onnxruntime/core/codegen/mti/tensor/gather.h deleted file mode 100644 index a44bf3e4127d5..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. 
All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gather(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name = "gather"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather_elements.cc b/onnxruntime/core/codegen/mti/tensor/gather_elements.cc deleted file mode 100644 index 12d2983335890..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather_elements.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/gather_elements.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor GatherElements(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name) { - tvm::Array output_shape; - int64_t indices_rank = static_cast(indices->shape.size()); - // output shape is the same as indices - for (int64_t i = 0; i < indices_rank; ++i) - output_shape.push_back(indices->shape[i]); - - tvm::Expr idx_upper_bound = t->shape[axis]; - auto l = [&](const tvm::Array& ovars) { - tvm::Array ivars; - for (int64_t i = 0; i < indices_rank; i++) { - if (i == axis) { - tvm::Array idx_vars; - for (int64_t j = 0; j < indices_rank; j++) - idx_vars.push_back(ovars[j]); - // make sure idx is clamped in the range of [-idx_upper_bound, idx_upper_bound - 1] - tvm::Expr real_idx = tvm_codegen::ClampIndex(indices(idx_vars), idx_upper_bound); - // tvm idx must be of Int(32) - ivars.push_back(tvm::cast(tvm::Int(32), real_idx)); - } else { - ivars.push_back(ovars[i]); - } - } - return t(ivars); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather_elements.h b/onnxruntime/core/codegen/mti/tensor/gather_elements.h deleted file mode 100644 index 650086f0f2e87..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather_elements.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor GatherElements(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name = "gather_elements"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/pad_ops.cc b/onnxruntime/core/codegen/mti/tensor/pad_ops.cc deleted file mode 100644 index 2f688290d109e..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/pad_ops.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/mti/tensor/pad_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// Note topi::pad does not support modes {edge, reflect} -// Therefore, MTI implements a generic Pad -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& pad_before, - const tvm::Array& pad_after, - float pad_value, - const std::string& mode, - const std::string& name) { - MTI_ASSERT(pad_before.size() >= 1); - MTI_ASSERT(pad_before.size() == pad_after.size()); - MTI_ASSERT(pad_before.size() == t->shape.size()); - - tvm::Array output_shape; - for (size_t i = 0; i < t->shape.size(); ++i) { - output_shape.push_back( - tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i])); - } - - auto l = [&](const tvm::Array& ovars) { - tvm::Array conds; - tvm::Array indices; - tvm::Array coords; - - for (size_t i = 0; i < t->shape.size(); ++i) { - tvm::Expr ivar = ovars[i] - pad_before[i]; - tvm::Expr min = 0; - tvm::Expr extent = t->shape[i]; - - conds.push_back(ivar < min); - conds.push_back(ivar >= min + extent); - indices.push_back(tvm::max(tvm::min(ivar, min + extent - 1), min)); - - if (mode == "reflect") { - // calculate indices for reflect mode - tvm::Expr limit = extent - 1; - tvm::Expr coord = ivar - min; - // Avoid mod zero when tensor shape has 1, - // e.g. input shape is [1, 3, 3] instead of [3, 3] - auto* p_limit = tvm::as_const_int(limit); - if (p_limit != nullptr && *p_limit != 0) - coord = (coord + 2 * limit) % (2 * limit); // avoid negative value - coord = coord - limit; - coord = tvm::abs(coord); - coord = limit - coord; - coord = coord + min; - coords.push_back(coord); - } - } - - if (mode == "reflect") { - return tvm::ir::Select::make(topi::detail::Map(conds, tvm::ir::Or::make), - t(coords), t(indices)); - } else if (mode == "constant") { - return tvm::ir::Select::make(topi::detail::Map(conds, tvm::ir::Or::make), - tvm::make_const(t->dtype, pad_value), t(indices)); - } - - // default mode is edge - return t(indices); - }; - - return tvm::compute(output_shape, l, name); -} - -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& output_shape, - const tvm::Expr& pad_value, - const std::string& name) { - MTI_ASSERT(t->dtype == pad_value.type()); - - auto l = [&](const tvm::Array& ovars) { - tvm::Array conds; - tvm::Array indices; - - for (size_t i = 0; i < t->shape.size(); ++i) { - tvm::Expr ivar = ovars[i]; - tvm::Expr min = 0; - tvm::Expr extent = t->shape[i]; - - conds.push_back(ivar < min); - conds.push_back(ivar >= min + extent); - indices.push_back(tvm::max(tvm::min(ivar, min + extent - 1), min)); - } - - return tvm::ir::Select::make(topi::detail::Map(conds, tvm::ir::Or::make), - pad_value, t(indices)); - }; - - return tvm::compute(output_shape, l, name); -} - -tvm::Tensor PadLastDim(const tvm::Tensor& t, - const int32_t align_size, - const tvm::Expr& pad_value, - const std::string& name) { - auto input_shape = t->shape; - tvm::Array out_shape; - size_t input_shape_rank = input_shape.size(); - for (size_t i = 0; i < input_shape_rank - 1; ++i) { - out_shape.push_back(input_shape[i]); - } - out_shape.push_back( - (input_shape[input_shape_rank - 1] + align_size - 1) / - align_size * align_size); - - return Pad(t, out_shape, pad_value, name + "_pad"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/pad_ops.h b/onnxruntime/core/codegen/mti/tensor/pad_ops.h deleted file mode 100644 index 6e8e350d71e97..0000000000000 --- 
a/onnxruntime/core/codegen/mti/tensor/pad_ops.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// ONNX Pad semantics -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& pad_before, - const tvm::Array& pad_after, - float pad_value = 0.0f, - const std::string& mode = "constant", - const std::string& name = "pad"); - -// Other common Pad interfaces -// Pad for a given shape -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& output_shape, - const tvm::Expr& pad_value, - const std::string& name = "pad"); - -// Pad for the last dim only. -// This is widely used for weight layout to guard alignment -tvm::Tensor PadLastDim(const tvm::Tensor& t, - const int32_t align_size, - const tvm::Expr& pad_value, - const std::string& name = "pad_last_dim"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/reshape_ops.cc b/onnxruntime/core/codegen/mti/tensor/reshape_ops.cc deleted file mode 100644 index 817fb32c2837a..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/reshape_ops.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/reshape_ops.h" - -#include "core/codegen/mti/common.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Flatten(const tvm::Tensor& X, int64_t axis, const std::string& name) { - const auto& input_shape = X->shape; - return Reshape(X, {SizeToDimension(input_shape, axis), SizeFromDimension(input_shape, axis)}, name); -} - -tvm::Tensor Identity(const tvm::Tensor& X, const std::string& name) { - return Reshape(X, X->shape, name); -} - -tvm::Tensor Reshape(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name) { - if (new_shape.size() > 0) { - auto X_dim = SizeToDimension(X->shape, X->shape.size()); - auto new_dim = SizeToDimension(new_shape, new_shape.size()); - auto* pX_dim = tvm::as_const_int(X_dim); - auto* pNew_dim = tvm::as_const_int(new_dim); - - if (pX_dim != nullptr && pNew_dim != nullptr) { - MTI_ASSERT(*pX_dim == *pNew_dim); - } - return topi::reshape(X, new_shape, name); - } else { - // generate empty dim tensor with origial input data value - tvm::Array tmp_shape; - tmp_shape.push_back(1); - auto tmp_tensor = topi::reshape(X, tmp_shape); - return tvm::compute( - new_shape, - [&](const tvm::Array&) { - return tmp_tensor[0]; - }, - name); - } -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/reshape_ops.h b/onnxruntime/core/codegen/mti/tensor/reshape_ops.h deleted file mode 100644 index e23d62e4c57b0..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/reshape_ops.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Flatten(const tvm::Tensor& X, int64_t axis, const std::string& name = "flatten"); -tvm::Tensor Identity(const tvm::Tensor& X, const std::string& name = "identity"); -tvm::Tensor Reshape(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name = "reshape"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/shape_op.cc b/onnxruntime/core/codegen/mti/tensor/shape_op.cc deleted file mode 100644 index b51bd67a8b2dc..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/shape_op.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/shape_op.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Shape(const tvm::Tensor& X, const std::string& name) { - int ndim = static_cast(X->shape.size()); - tvm::Array out_shape{ndim}; - return tvm::compute( - out_shape, [&](const tvm::Array& indices) { - auto idx = indices[0]; - tvm::Expr ret = 0; - for (int i = 0; i < ndim; ++i) { - ret = tvm::ir::Select::make(idx == i, X->shape[i], ret); - } - return tvm::cast(HalideIR::Int(64), ret); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/slice.cc b/onnxruntime/core/codegen/mti/tensor/slice.cc deleted file mode 100644 index 6cbab43584d4b..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/slice.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/slice.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// local constexpr for INT_MAX -constexpr int64_t max_range = INT_MAX; - -tvm::Expr position(const tvm::Expr& dim, const tvm::Integer& offset, bool allow_out_of_bound = false) { - if (offset->value >= max_range) { - return allow_out_of_bound ? dim : dim - 1; - } else if (offset->value <= -max_range) { - return tvm::make_const(HalideIR::Int(32), allow_out_of_bound ? -1 : 0); - } else { - if (offset->value >= 0) { - return tvm::ir::Simplify(tvm::ir::Min::make(offset, dim + (allow_out_of_bound ? 0 : -1))); - } else { - return tvm::ir::Simplify(dim + tvm::ir::Max::make(offset, -dim + (allow_out_of_bound ? -1 : 0))); - } - } -} - -tvm::Tensor Slice(const tvm::Tensor& X, - const std::vector& starts, - const std::vector& ends, - const std::vector& axes1, - const std::vector& steps, - const std::string& name) { - MTI_ASSERT(starts.size() == ends.size()); - MTI_ASSERT(starts.size() == axes1.size()); - MTI_ASSERT(starts.size() == steps.size()); - - std::vector axes; - for (const auto& i : axes1) { - axes.push_back(HandleNegativeAxis(i, X->shape.size())); - } - - tvm::Array output_shape; - bool empty = false; - for (int64_t i = 0; i < gsl::narrow(X->shape.size()); ++i) { - auto axes_iter = std::find(axes.begin(), axes.end(), i); - if (axes_iter != axes.end()) { - auto axis = axes_iter - axes.begin(); - tvm::Expr start = position(X->shape[i], starts[axis]); - tvm::Expr end = position(X->shape[i], ends[axis], /*allow_out_of_bound*/ true); - auto dim = tvm::ir::Simplify((end - start + tvm::Integer(steps[axis] + (steps[axis] < 0 ? 
1 : -1))) / tvm::Integer(steps[axis])); - auto int_dim = tvm::as_const_int(dim); - if (int_dim && *int_dim <= 0) { - output_shape.push_back(0); - empty = true; - } else { - output_shape.push_back(dim); - } - } else { - output_shape.push_back(X->shape[i]); - } - } - - if (empty) { - return MakeZeroTensor(output_shape, X->dtype, name); - } - - return tvm::compute( - output_shape, - [&](const tvm::Array& ovars) { - tvm::Array ivars; - for (size_t i = 0; i < X->shape.size(); ++i) { - auto axes_iter = std::find(axes.begin(), axes.end(), i); - if (axes_iter != axes.end()) { - auto axis = axes_iter - axes.begin(); - ivars.push_back(tvm::ir::Simplify(ovars[i] * tvm::Integer(steps[axis]) + position(X->shape[i], starts[axis]))); - } else { - ivars.push_back(ovars[i]); - } - } - return X(ivars); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/slice.h b/onnxruntime/core/codegen/mti/tensor/slice.h deleted file mode 100644 index ac5c9437791f6..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/slice.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Slice(const tvm::Tensor& X, - const std::vector& starts, - const std::vector& ends, - const std::vector& axes, - const std::vector& steps, - const std::string& name = "slice"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/split.cc b/onnxruntime/core/codegen/mti/tensor/split.cc deleted file mode 100644 index 6ee366314858f..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/split.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/split.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// Similar to numpy, topi::split takes split indices rather than the -// sizes of the splits. Thus we implement our own. 
-tvm::Array Split(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name) { - MTI_ASSERT(axis < gsl::narrow(X->shape.size())); - size_t axis_t = gsl::narrow(axis); - - tvm::Array> output_shapes; - int num_splits = gsl::narrow(split_sizes.size()); - for (auto& s : split_sizes) { - tvm::Array shape; - for (size_t i = 0; i < axis_t; i++) { - shape.push_back(X->shape[i]); - } - shape.push_back(s); - for (size_t i = axis_t + 1; i < X->shape.size(); i++) { - shape.push_back(X->shape[i]); - } - output_shapes.push_back(shape); - } - - tvm::Array res; - int idx = 0; - for (int i_split = 0; i_split < num_splits; ++i_split) { - tvm::Expr s = split_sizes[i_split]; - auto l = [&](const tvm::Array& indices) { - tvm::Array new_indices; - for (size_t i = 0; i < axis_t; i++) { - new_indices.push_back(indices[i]); - } - new_indices.push_back(indices[axis_t] + idx); - for (size_t i = axis_t + 1; i < X->shape.size(); i++) { - new_indices.push_back(indices[i]); - } - MTI_ASSERT(topi::detail::IsConstInt(s)); - MTI_ASSERT(new_indices.size() == X->shape.size()); - int size = topi::detail::GetConstInt(s); - idx += size; - return X(new_indices); - }; - res.push_back(tvm::compute(output_shapes[i_split], l, name)); - } - - MTI_ASSERT(topi::detail::IsConstInt(X->shape[axis_t])); - int size_of_splitted_axis = static_cast(topi::detail::GetConstInt(X->shape[axis_t])); - MTI_ASSERT(idx == size_of_splitted_axis); - return res; -} - -tvm::Array SplitWithIndices(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name) { - return topi::split(X, split_sizes, gsl::narrow(axis), name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/split.h b/onnxruntime/core/codegen/mti/tensor/split.h deleted file mode 100644 index bcb9c47d936dd..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/split.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// ONNX Split semantics -tvm::Array Split(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name = "split"); - -// Another common Split interface -// Split with chunck indices -tvm::Array SplitWithIndices(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name = "split_with_indices"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/tile.cc b/onnxruntime/core/codegen/mti/tensor/tile.cc deleted file mode 100644 index 2fef86adcbaea..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/tile.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/mti/tensor/tile.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Tile(const tvm::Tensor& t, - const std::vector& repeats, - const std::string& name) { - MTI_ASSERT(repeats.size() == t->shape.size()); - tvm::Array output_shape; - - bool repeats_zero = false; - for (size_t i = 0; i < t->shape.size(); ++i) { - if (repeats[i] == 0) - repeats_zero = true; - output_shape.push_back(t->shape[i] * gsl::narrow(repeats[i])); - } - - auto l = [&](const tvm::Array& ovars) { - if (repeats_zero) - return tvm::make_zero(t->dtype); - - tvm::Array ivars; - for (size_t i = 0; i < t->shape.size(); ++i) { - tvm::Expr ovar = ovars[i]; - ivars.push_back(ovar % t->shape[i]); - } - return t(ivars); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/tile.h b/onnxruntime/core/codegen/mti/tensor/tile.h deleted file mode 100644 index 7ce331fb5ea95..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/tile.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Tile(const tvm::Tensor& t, - const std::vector& repeats, - const std::string& name = "tile"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/transpose.cc b/onnxruntime/core/codegen/mti/tensor/transpose.cc deleted file mode 100644 index 873ff8d7f1708..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/transpose.cc +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/transpose.h" - -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Transpose(const tvm::Tensor& X, const tvm::Array& axes, const std::string& name) { - return topi::transpose(X, axes, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/transpose.h b/onnxruntime/core/codegen/mti/tensor/transpose.h deleted file mode 100644 index a2a98fedf1e79..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/transpose.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Transpose(const tvm::Tensor& X, - const tvm::Array& axes, - const std::string& name = "transpose"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/where.cc b/onnxruntime/core/codegen/mti/tensor/where.cc deleted file mode 100644 index 2bdac3cae7ef5..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/where.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/where.h" - -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Where(const tvm::Tensor& B, - const tvm::Tensor& X, - const tvm::Tensor& Y, - const std::string& name) { - size_t rank = std::max(std::max(B->shape.size(), X->shape.size()), Y->shape.size()); - tvm::Array output_shape; - for (size_t i = 0; i < rank; ++i) { - tvm::Expr dim = tvm::make_const(HalideIR::Int(32), 1); - bool broadcasted = - BroadcastDim(B->shape, i, rank, dim) && - BroadcastDim(X->shape, i, rank, dim) && - BroadcastDim(Y->shape, i, rank, dim); - MTI_ASSERT(broadcasted); - output_shape.push_back(dim); - } - - return topi::where(topi::broadcast_to(B, output_shape), - topi::broadcast_to(X, output_shape), - topi::broadcast_to(Y, output_shape), - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/where.h b/onnxruntime/core/codegen/mti/tensor/where.h deleted file mode 100644 index 68c5288eb3580..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/where.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Where(const tvm::Tensor& B, - const tvm::Tensor& X, - const tvm::Tensor& Y, - const std::string& name = "where"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/all_ops.h b/onnxruntime/core/codegen/passes/op_ir_creator/all_ops.h deleted file mode 100644 index 1463e50bd72fb..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/all_ops.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/codegen/common/op_macro.h" -#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// This macro declares a TVM IR builder -// based on ORT OP type with postfix DefaultTVM -#define DECLARE_GENERIC_OP_IR_CREATOR_CLASS(OP) \ - DECLARE_OP_IR_CREATOR_CLASS(OP, DefaultTVM) - -// This macro returns a TVM IR builder class name -// based ORT OP type with postfix DefaultTVM -#define GENERIC_OP_IR_CREATOR_CLASS(OP) \ - CREATOR_CLASS(OP, DefaultTVM##IRCreator) - -#define GENERIC_OP_IR_CREATOR_STRING(OP) \ - STRINGIZE(GENERIC_OP_IR_CREATOR_CLASS(OP)) - -// define all ops for DefaultTVM -#define ADD_OP_ITEM(OP) DECLARE_GENERIC_OP_IR_CREATOR_CLASS(OP) -#define BINARY_OP(OP) ADD_OP_ITEM(OP) -#define BINARY_CMP_OP(OP) ADD_OP_ITEM(OP) -#define POOL_OP(OP) ADD_OP_ITEM(OP) -#define UNARY_OP(OP) ADD_OP_ITEM(OP) -#define VARIADIC_OP(OP) ADD_OP_ITEM(OP) -#define REDUCE_INDEXED_OP(OP) ADD_OP_ITEM(OP) -#define REDUCE_OP(OP) ADD_OP_ITEM(OP) - -LIST_ALL_GENERIC_OPS() - -#undef ADD_OP_ITEM -#undef BINARY_OP -#undef BINARY_CMP_OP -#undef POOL_OP -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP -#undef UNARY_OP -#undef VARIADIC_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/binary_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/binary_ops.cc deleted file mode 100644 index 9452146621ac7..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/binary_ops.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/tensor/cast_ops.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// helper local macro defines Evaluate of BINARY_OP OpIRCreators -#define BINARY_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = name(inputs[0], inputs[1], node.Name()); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_BINARY_OPS() - -#undef BINARY_OP - -// helper local macro defines Evaluate of BINARY_CMP_OP OpIRCreators -#define BINARY_CMP_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = Cast(name(inputs[0], inputs[1], node.Name()), HalideIR::UInt(8), "cast_bool_" #name); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_BINARY_CMP_OPS() - -#undef BINARY_CMP_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/clip.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/clip.cc deleted file mode 100644 index bb33e6e70accf..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/clip.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/unary_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Clip OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Clip)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int version = ctx_codegen.GetCodeGenHandle()->domain_version_lookup_func(node.Domain()); - tvm::Expr min_value, max_value; - if (version < 11) { - float max_v, min_v; - info.GetAttrOrDefault("min", &min_v, std::numeric_limits::lowest()); - info.GetAttrOrDefault("max", &max_v, std::numeric_limits::max()); - min_value = tvm::make_const(tvm::Float(32), min_v); - max_value = tvm::make_const(tvm::Float(32), max_v); - } else { - // for op_version >= 11, max and min are optional inputs - min_value = tvm::make_const(tvm::Float(32), std::numeric_limits::lowest()); - max_value = tvm::make_const(tvm::Float(32), std::numeric_limits::max()); - auto num_inputs = inputs.size(); - if (num_inputs >= 2 && inputs[1].defined()) { - min_value = inputs[1](); - } - if (num_inputs == 3 && inputs[2].defined()) { - max_value = inputs[2](); - } - } - - tvm::Tensor Y = Clip(inputs[0], min_value, max_value, node.Name() + "_Clip"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/gemm.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/gemm.cc deleted file mode 100644 index 64f995076e1bb..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/gemm.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/gemm.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Gemm OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Gemm)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& /*ctx_codegen*/, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - tvm::Tensor A = inputs[0]; - tvm::Tensor B = inputs[1]; - tvm::Tensor C = inputs[2]; - - int64_t trans_A, trans_B; - ORT_RETURN_IF_ERROR(attrs.GetAttr("transA", &trans_A)); - ORT_RETURN_IF_ERROR(attrs.GetAttr("transB", &trans_B)); - - float alpha, beta; - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha).IsOK()); - ORT_ENFORCE(attrs.GetAttr("beta", &beta).IsOK()); - - tvm::Tensor Y = Gemm(A, B, C, trans_A != 0, trans_B != 0, alpha, beta, node.Name() + "_Gemm"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/logsoftmax.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/logsoftmax.cc deleted file mode 100644 index cb09518bf63d1..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/logsoftmax.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/logsoftmax.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of LogSoftmax OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(LogSoftmax)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis_i64; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis_i64)); - axis_i64 = HandleNegativeAxis(axis_i64, gsl::narrow_cast(inputs[0]->shape.size())); - - tvm::Tensor Y = LogSoftmax(inputs[0], axis_i64, node.Name() + "_LogSoftmax"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/matmul.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/matmul.cc deleted file mode 100644 index ab1ac237bfa5d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/matmul.cc +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/matmul_ops.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of MatMul OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(MatMul)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - tvm::Tensor Y = MatMul(inputs[0], inputs[1], node.Name() + "_MatMul"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc deleted file mode 100644 index 6f66b1f1a2afb..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/tensor/cast_ops.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of MatMulInteger OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(MatMulInteger)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - const auto& A = inputs[0]; - const auto& B = inputs[1]; - auto& name = node.Name(); - - // A generic path, cast to int32 - // Support skipped trailing inputs - auto A_Int32 = (node.InputDefs().size() >= 3 && node.InputDefs()[2]->Exists()) - ? Sub(Cast(A, HalideIR::Int(32)), Cast(inputs[2], HalideIR::Int(32))) - : Cast(A, HalideIR::Int(32)); - auto B_Int32 = (node.InputDefs().size() >= 4 && node.InputDefs()[3]->Exists()) - ? 
Sub(Cast(B, HalideIR::Int(32)), Cast(inputs[3], HalideIR::Int(32))) - : Cast(B, HalideIR::Int(32)); - tvm::Tensor Y = MatMul(A_Int32, B_Int32, name + "_MatMulInteger"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/reduce_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/reduce_ops.cc deleted file mode 100644 index f29a3f3e7cdf7..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/reduce_ops.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/mti/math/reduce_ops.h" -#include "core/codegen/mti/tensor/cast_ops.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -using ReduceIndexedFunc = tvm::Tensor (*)(const tvm::Tensor& X, int64_t axis, bool keep_dims, const std::string& name); -using ReduceFunc = tvm::Tensor (*)(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name); - -// helper class for for REDUCE_INDEXED_OP -class FuncReduceIndexed { - public: - FuncReduceIndexed(const Node& node, ReduceIndexedFunc func, const std::string& name) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - axis_ = info.GetAttrOrDefault("axis", 0); - int64_t keepdims_i = 1; - ORT_ENFORCE(info.GetAttr("keepdims", &keepdims_i).IsOK()); - keep_dims_ = (keepdims_i == 1); - func_ = func; - name_ = name; - } - - tvm::Tensor operator()(const tvm::Tensor& X) const { - auto axis = HandleNegativeAxis(axis_, gsl::narrow_cast(X->shape.size())); - tvm::Tensor index32 = func_(X, axis, keep_dims_, name_); - return Cast(index32, tvm::Int(64)); - } - - private: - int64_t axis_; - bool keep_dims_; - ReduceIndexedFunc func_; - std::string name_; -}; - -// helper class for REDUCE_OP -class FuncReduce { - public: - FuncReduce(const Node& node, ReduceFunc func, const std::string& name) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - axes_ = info.GetAttrsOrDefault("axes"); - int64_t keepdims_i = 1; - ORT_ENFORCE(info.GetAttr("keepdims", &keepdims_i).IsOK()); - keep_dims_ = (keepdims_i == 1); - func_ = func; - name_ = name; - } - - tvm::Tensor operator()(const tvm::Tensor& X) const { - std::vector axes; - for (auto i : axes_) - axes.push_back(HandleNegativeAxis(i, gsl::narrow_cast(X->shape.size()))); - - return func_(X, axes, keep_dims_, name_); - } - - private: - std::vector axes_; - bool keep_dims_; - ReduceFunc func_; - std::string name_; -}; - -// helper macro defines Evaluate of REDUCE_OP OpIRCreators -#define REDUCE_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y; \ - if (ShapeRank(node.OutputDefs()[0]) == 0) { \ - tvm::Tensor temp = FuncReduce(node, &name, #name)(inputs[0]); \ - Y = Reshape(temp, {}); \ - } else { \ - Y = FuncReduce(node, &name, #name)(inputs[0]); \ - } \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -// helper macro defines Evaluate of REDUCE_INDEXED_OP OpIRCreators -#define REDUCE_INDEXED_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, 
\ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = FuncReduceIndexed(node, &name, #name)(inputs[0]); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_REDUCE_OPS() - -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/softmax.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/softmax.cc deleted file mode 100644 index 7b13de5a94e48..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/softmax.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/softmax.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Softmax OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Softmax)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis_i64; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis_i64)); - - axis_i64 = HandleNegativeAxis(axis_i64, gsl::narrow_cast(inputs[0]->shape.size())); - tvm::Tensor Y = Softmax(inputs[0], axis_i64, node.Name() + "_Softmax"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_funcs.h b/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_funcs.h deleted file mode 100644 index 29e6519af0ef1..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_funcs.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { -// helper class for unary_ops with alpha -class FuncWithAlpha { - public: - FuncWithAlpha(const Node& node) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); - } - - protected: - float alpha_; -}; - -// helper class for unary_ops with alpha and beta -class FuncWithAlphaBeta { - public: - FuncWithAlphaBeta(const Node& node) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); - ORT_ENFORCE(attrs.GetAttr("beta", &beta_).IsOK()); - } - - protected: - float alpha_; - float beta_; -}; - -// helper class for unary_ops with alpha and gamma -class FuncWithAlphaGamma { - public: - FuncWithAlphaGamma(const Node& node) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); - ORT_ENFORCE(attrs.GetAttr("gamma", &gamma_).IsOK()); - } - - protected: - float alpha_; - float gamma_; -}; -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_ops.cc deleted file mode 100644 index 0407c0a06abf6..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_ops.cc +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
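The MatMulInteger creator deleted earlier in this hunk widens both operands to int32 and subtracts the optional zero-point inputs before the matrix multiply. A scalar reference sketch of that quantized arithmetic (MatMulIntegerRef is a hypothetical name; the real path builds TVM tensor expressions instead of loops):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Y[m][n] = sum_k (A[m][k] - a_zero) * (B[k][n] - b_zero), accumulated in int32.
std::vector<int32_t> MatMulIntegerRef(const std::vector<uint8_t>& A,
                                      const std::vector<uint8_t>& B,
                                      int M, int K, int N,
                                      int32_t a_zero, int32_t b_zero) {
  std::vector<int32_t> Y(static_cast<std::size_t>(M) * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        const int32_t a = static_cast<int32_t>(A[m * K + k]) - a_zero;
        const int32_t b = static_cast<int32_t>(B[k * N + n]) - b_zero;
        acc += a * b;
      }
      Y[m * N + n] = acc;
    }
  return Y;
}

int main() {
  // 1x2 times 2x1 with zero points 128 and 0.
  std::vector<uint8_t> A{130, 126}, B{2, 3};
  auto Y = MatMulIntegerRef(A, B, 1, 2, 1, 128, 0);
  std::cout << Y[0] << "\n";  // (130-128)*2 + (126-128)*3 = -2
}

When a zero-point input is absent, the deleted code simply skips the subtraction, which is the a_zero = b_zero = 0 case of the sketch.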
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/passes/op_ir_creator/math/unary_funcs.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// helper macro declares unary_ops helper class without attribute -#define FuncClass(name) \ - class Func##name { \ - public: \ - Func##name(const Node&) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X); \ - } \ - } - -// helper macro declares unary_ops helper class with alpha -#define FuncClassAlpha(name) \ - class Func##name : public FuncWithAlpha { \ - public: \ - Func##name(const Node& node) : FuncWithAlpha(node) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X, alpha_); \ - } \ - } - -// helper macro declares unary_ops helper class with alpha and beta -#define FuncClassAlphaBeta(name) \ - class Func##name : public FuncWithAlphaBeta { \ - public: \ - Func##name(const Node& node) : FuncWithAlphaBeta(node) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X, alpha_, beta_); \ - } \ - } - -// helper macro declares unary_ops helper class with alpha and gamma -#define FuncClassAlphaGamma(name) \ - class Func##name : public FuncWithAlphaGamma { \ - public: \ - Func##name(const Node& node) : FuncWithAlphaGamma(node) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X, alpha_, gamma_); \ - } \ - } - -FuncClass(Abs); -FuncClassAlphaBeta(Affine); -FuncClass(Ceil); -FuncClassAlpha(Elu); -FuncClass(Exp); -FuncClass(Floor); -FuncClassAlphaBeta(HardSigmoid); -FuncClassAlpha(LeakyRelu); -FuncClass(Log); -FuncClass(Neg); -FuncClassAlphaBeta(ParametricSoftplus); -FuncClass(Reciprocal); -FuncClass(Relu); -FuncClassAlphaBeta(ScaledTanh); -FuncClassAlphaGamma(Selu); -FuncClass(Sigmoid); -FuncClass(Softplus); -FuncClass(Softsign); -FuncClass(Sqrt); -FuncClass(Tanh); -FuncClassAlpha(ThresholdedRelu); - -// helper macro defines Evaluate of UNARY_OP OpIRCreators -#define UNARY_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = Func##name(node)(inputs[0]); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -// helper local macros to replace some calls in LIST_UNARY_OPS -LIST_UNARY_OPS() - -#undef UNARY_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/variadic_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/variadic_ops.cc deleted file mode 100644 index 9559a713c2876..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/variadic_ops.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Sum(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return Add(lhs, rhs, name); -} - -// helper local macro defines Evaluate of BINARY_OP OpIRCreators -#define VARIADIC_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = Identity(inputs[0], node.Name() + "0"); \ - for (size_t i = 1; i < inputs.size(); ++i) \ - Y = name(Y, inputs[i], node.Name() + std::to_string(i)); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_VARIADIC_OPS() - -#undef VARIADIC_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc deleted file mode 100644 index 19545d1554405..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/nn/conv_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/concat_ops.h" -#include "core/codegen/mti/tensor/split.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -Status GENERIC_OP_IR_CREATOR_CLASS(Conv)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - // Attributes - int64_t group; - std::string auto_pad; - std::vector kernel_shape, strides, dilations, pads; - - info.GetAttrOrDefault("group", &group, 1); - info.GetAttrOrDefault("auto_pad", &auto_pad, "NOTSET"); - - ORT_THROW_IF_ERROR(info.GetAttrs("kernel_shape", kernel_shape)); - ORT_ENFORCE(kernel_shape.size() <= 2, "Only support 1D/2D convolution currently!"); - ORT_THROW_IF_ERROR(info.GetAttrs("strides", strides)); - - dilations = info.GetAttrs("dilations", dilations).IsOK() ? dilations : std::vector(kernel_shape.size(), 1); - ORT_ENFORCE(dilations == std::vector(kernel_shape.size(), 1), "Only support dilation is 1 currently"); - - pads = info.GetAttrs("pads", pads).IsOK() ? pads : std::vector(kernel_shape.size() * 2, 0); - - // auto_pad - if (auto_pad != "NOTSET") { - auto rank = inputs[0]->shape.size() - 2; - ORT_ENFORCE(rank > 0); - for (uint64_t i = 0; i < rank; i++) { - if (auto_pad == "VALID") { - pads[i] = 0; - pads[i + rank] = 0; - } else if (auto_pad == "SAME_UPPER" || auto_pad == "SAME_LOWER") { - // TODO: handle symbolic dim - ORT_ENFORCE(ShapeHasValue(node.InputDefs()[0], 2 + i)); - - int64_t input_dim_value = ShapeValue(node.InputDefs()[0], 2 + i); - int64_t output_dim_value = (input_dim_value + strides[i] - 1) / strides[i]; - int64_t pad_needed = (output_dim_value - 1) * strides[i] + kernel_shape[i] - input_dim_value; - - pads[i] = auto_pad == "SAME_LOWER" ? 
(pad_needed + 1) / 2 : pad_needed / 2; - pads[i + rank] = pad_needed - pads[i]; - } else { - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown auto_pad value"); - } - } - } - - // Inputs - tvm::Tensor X = inputs[0]; - tvm::Tensor W = inputs[1]; - // Outputs - tvm::Tensor Y; - tvm::Array Y_shape = ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen); - - // 1-D convolution - if (kernel_shape.size() == 1) { - Y = Conv1D(X, W, Y_shape, ToTvmArray(strides), ToTvmArray(pads), node.Name() + "_Conv1D"); - } - // 2-D convolution - else if (kernel_shape.size() == 2) { - if (group == 1) { - Y = Conv2D(X, W, Y_shape, ToTvmArray(strides), ToTvmArray(pads), node.Name() + "_Conv2D"); - } else { - int64_t channel_out = ShapeValue(node.InputDefs()[1], 0); - int64_t channel_in = ShapeValue(node.InputDefs()[1], 1); - ORT_ENFORCE(channel_out % group == 0); - - int64_t cout_group = channel_out / group; - Y_shape.Set(1, Y_shape[1] / gsl::narrow_cast(group)); - - tvm::Array split_index0; - tvm::Array split_index1; - - for (int i = 1; i < group; i++) { - split_index0.push_back(i * channel_in); - split_index1.push_back(i * cout_group); - } - - auto input_groups = SplitWithIndices(X, split_index0, 1); - auto weight_groups = SplitWithIndices(W, split_index1, 0); - - // FIXME: This will trigger a llvm buffer overflow when group is too large - // TODO: fix this change it to batched gemm/conv - tvm::Array output_tensors; - for (int i = 0; i < group; i++) { - auto output_tensor = Conv2D(input_groups[i], - weight_groups[i], - Y_shape, - ToTvmArray(strides), - ToTvmArray(pads), - node.Name() + "_Conv2D"); - output_tensors.push_back(output_tensor); - } - Y = Concat(output_tensors, 1); - } - } - - // Add bias if provided - // Support skipped trailing inputs - if (node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) { - tvm::Tensor B = inputs[2]; - Y = tvm::compute( - Y_shape, - [&](const tvm::Array& indices) { - return Y(indices) + B(indices[1]); - }); - } - - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc deleted file mode 100644 index 88170bb56dd2d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/nn/lstm.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// In the cell computation, we don't have the "direction" dimension and sequence dimension, -// which have been processed outside of the cell. -// Here we implement an LTSM cell. -// For those args (inputs/outputs) of hidden states we put AFTER regular args (inputs/outputs) -// with a pre-defined order -// In a LSTM, the order is H and then C. 
-// Ouputs of LSTM is Y_h and then Y_c -Status GENERIC_OP_IR_CREATOR_CLASS(LSTM)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - std::string direction_attr; - ORT_RETURN_IF_ERROR(attrs.GetAttr("direction", &direction_attr)); - int64_t hidden_size; - ORT_RETURN_IF_ERROR(attrs.GetAttr("hidden_size", &hidden_size)); - - // input tensor with shape [seq_length, batch_size, input_size] - const tvm::Tensor& X = inputs[0]; // input tensor with shape [seq_length, batch_size, input_size] - const tvm::Tensor& W = inputs[1]; // weights tensor with shape [4*hidden_size, input_size] - const tvm::Tensor& R = inputs[2]; // recurrence tensor with shape [4*hidden_size, hidden_size] - const tvm::Tensor& B = inputs[3]; // optional bias tensor with shape [8*hidden_size] - bool has_B = node.InputDefs()[3]->Exists(); - - // Unsupported the 4th inputs - // optional tensor specifying sequence lengths in a batch, shape: [batch_size] - // const tvm::Tensor* seq_len = inputs[4] ? &inputs[4]->tensor : nullptr; - - const tvm::Tensor& prev_H = inputs[5]; // optional initial H, shape: [batch_size, hidden_size] - const tvm::Tensor& prev_C = inputs[6]; // optional initial C, shape: [batch_size, hidden_size] - - const tvm::Tensor& P = inputs[7]; // optional peepholes tensor with shape [3*hidde_size] - bool has_P = node.InputDefs()[7]->Exists(); - - tvm::Tensor Y_h; // shape: [batch_size, hidden_size] - tvm::Tensor Y_c; // shape: [batch_size, hidden_size] - LSTMAttributes lstm_attrs(hidden_size); - LSTM_cell(lstm_attrs, X, W, R, B, has_B, prev_H, prev_C, P, has_P, Y_h, Y_c); - - // Since we only generate lstm cell, lstm's states need to be always outputs, - // regardless whethere they are skipped or not. - // The skipped trailing outputs need to be handled by Execution - outputs.push_back(Y_h); - outputs.push_back(Y_c); - - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/pool_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/pool_ops.cc deleted file mode 100644 index 84d3b7c1e0f79..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/pool_ops.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
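The Conv creator deleted above derives explicit pads when auto_pad is SAME_UPPER or SAME_LOWER: the output extent is ceil(input / stride), the total padding is (output - 1) * stride + kernel - input, and SAME_LOWER places the extra element at the front. A minimal sketch of that arithmetic with dilation fixed to 1, matching the deleted code's restriction (SamePads is an illustrative helper, not an ORT API):

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

// Returns {pad_begin, pad_end} for one spatial dimension.
std::pair<int64_t, int64_t> SamePads(int64_t input, int64_t kernel,
                                     int64_t stride, const std::string& auto_pad) {
  const int64_t output = (input + stride - 1) / stride;  // ceil(input / stride)
  int64_t pad_needed = (output - 1) * stride + kernel - input;
  if (pad_needed < 0) pad_needed = 0;  // extra guard, not present in the deleted code
  const int64_t begin = (auto_pad == "SAME_LOWER") ? (pad_needed + 1) / 2
                                                   : pad_needed / 2;
  return {begin, pad_needed - begin};
}

int main() {
  auto [b, e] = SamePads(/*input=*/5, /*kernel=*/3, /*stride=*/2, "SAME_UPPER");
  std::cout << b << " " << e << "\n";  // output = 3, pad_needed = 2 -> pads {1, 1}
}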
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/nn/pool_ops.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/cpu/nn/pool_attributes.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// A local macro to create Pool Ops - -// helper macro defines Evaluate of of POOL_OP OpIRCreators -#define POOL_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext& ctx_codegen, \ - tvm::Array& outputs) { \ - ORT_RETURN_IF_NOT(outputs.size() == 1, "multiple outputs are not supported yet!"); \ - ProtoHelperNodeContext ctx(node); \ - OpNodeProtoHelper info(&ctx); \ - int version = ctx_codegen.GetCodeGenHandle()->domain_version_lookup_func(node.Domain()); \ - PoolAttributes pool_attrs(info, #name, version); \ - for (auto n : pool_attrs.dilations) { \ - ORT_RETURN_IF_NOT(n <= 1, "dilations are not supported yet!"); \ - } \ - if (pool_attrs.global_pooling) { \ - if (inputs[0]->shape.size() != 4) { \ - ORT_NOT_IMPLEMENTED(gsl::narrow_cast(inputs[0]->shape.size()) - 2, "d global pooling is not implementated"); \ - } \ - } else { \ - if (pool_attrs.kernel_shape.size() != 2) { \ - ORT_NOT_IMPLEMENTED(pool_attrs.kernel_shape.size(), "d pooling is not implementated"); \ - } \ - } \ - tvm::Array dummy_output_shape; \ - tvm::Tensor Y = name(inputs[0], pool_attrs, dummy_output_shape); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_POOL_OPS() - -#undef POOL_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/cast.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/cast.cc deleted file mode 100644 index bd324fd359edf..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/cast.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/cast_ops.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Cast OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Cast)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t to; - ORT_RETURN_IF_ERROR(attrs.GetAttr("to", &to)); - auto to_type_proto = gsl::narrow_cast(to); - - tvm::Tensor X = inputs[0]; - tvm::Tensor Y; - if (to_type_proto == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { - // special case for bool as ONNX bool is uint8, while in tvm it's uint1 - Y = CastToUInt8Bool(X, node.Name() + "_Cast"); - } else { - Y = Cast(X, ToTvmType(to_type_proto), node.Name() + "_Cast"); - } - - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/concat.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/concat.cc deleted file mode 100644 index 418296889419e..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/concat.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/concat_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Concat OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Concat)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis)); - - tvm::Tensor Y = Concat(inputs, axis, node.Name() + "_Concat"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc deleted file mode 100644 index 3b6a9a76f0723..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/crop.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Crop OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Crop)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - if (inputs[0]->shape.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input is expected to have four dimensions corresponding to [N,C,H,W]"); - } - - std::vector border; - std::vector scale; - - ORT_ENFORCE(attrs.GetAttrs("border", border).IsOK()); - // scale is optional and status is false when omit - bool is_ok = attrs.GetAttrs("scale", scale).IsOK(); - ORT_UNUSED_PARAMETER(is_ok); - - if (border.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Attribute border needs to be specified with four border elements"); - } - - tvm::Tensor Y = Crop(inputs[0], ToTvmArray(border), ToTvmArray(scale), node.Name() + "_Crop"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/expand.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/expand.cc deleted file mode 100644 index 0f0e0cf0987b3..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/expand.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/expand.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Expand OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Expand)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Expand(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Expand"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather.cc deleted file mode 100644 index 3a5d801b6839f..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/gather.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Gather OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Gather)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t axis; - ORT_ENFORCE(attrs.GetAttr("axis", &axis).IsOK()); - - tvm::Tensor Y = Gather(inputs[0], axis, inputs[1], node.Name() + "_Gather"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather_elements.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather_elements.cc deleted file mode 100644 index 0b71506cceed3..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather_elements.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/gather_elements.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of GatherElements OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(GatherElements)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t axis; - ORT_ENFORCE(attrs.GetAttr("axis", &axis).IsOK()); - axis = HandleNegativeAxis(axis, gsl::narrow_cast(inputs[0]->shape.size())); - - tvm::Tensor Y = GatherElements(inputs[0], axis, inputs[1], node.Name() + "_GatherElements"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc deleted file mode 100644 index e9e20e8a43998..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/pad_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Pad OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Pad)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - std::string mode; - std::vector pads; - float value; - - ORT_THROW_IF_ERROR(attrs.GetAttr("mode", &mode)); - ORT_THROW_IF_ERROR(attrs.GetAttrs("pads", pads)); - ORT_THROW_IF_ERROR(attrs.GetAttr("value", &value)); - - if (mode != "constant" && mode != "edge" && mode != "reflect") - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Pad: Unsupported padding mode!"); - - if (pads.size() != 2 * inputs[0]->shape.size()) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Pad: pads rank does not match inputs rank!"); - - std::vector pad_before, pad_after; - size_t offset = pads.size() / 2; - for (size_t i = 0; i < offset; i++) { - pad_before.push_back(pads[i]); - pad_after.push_back(pads[i + offset]); - } - - tvm::Tensor Y = Pad(inputs[0], ToTvmArray(pad_before), ToTvmArray(pad_after), value, mode, node.Name() + "_Pad"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/reshape_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/reshape_ops.cc deleted file mode 100644 index a83f598bc8ad1..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/reshape_ops.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Dropout OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Dropout)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Identity(inputs[0]); - outputs.push_back(Y); - - // optional mask - // Support skipped trailing outputs - if (node.OutputDefs().size() > 1 && node.OutputDefs()[1]->Exists()) { - // A fake mask with all ones - auto l = [&](const tvm::Array& /*indices*/) { - return tvm::make_const(tvm::UInt(8), 1); - }; - tvm::Tensor mask = tvm::compute(inputs[0]->shape, l, "mask"); - outputs.push_back(mask); - } - - return Status::OK(); -} - -// Evaluate of Flatten OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Flatten)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t axis; - ORT_RETURN_IF_ERROR(attrs.GetAttr("axis", &axis)); - - tvm::Tensor Y = Flatten(inputs[0], axis, node.Name() + "_Flatten"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Identity OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Identity)::Evaluate( - const tvm::Array& inputs, - const Node&, - CodeGenContext&, - tvm::Array& outputs) { - tvm::Tensor Y = Identity(inputs[0]); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Reshape OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Reshape)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Reshape(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Reshape"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Squeeze OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Squeeze)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Reshape(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Squeeze"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Unsqueeze OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Unsqueeze)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Reshape(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Unsqueeze"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/shape_op.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/shape_op.cc deleted file mode 100644 index 84761ecac1397..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/shape_op.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/shape_op.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Expand OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Shape)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Shape(inputs[0], node.Name() + "_Expand"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/slice.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/slice.cc deleted file mode 100644 index 6a016580c41e4..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/slice.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/slice.h" -#include "core/framework/op_kernel_info.h" -#include "core/framework/tensorprotoutils.h" - -#include - -namespace onnxruntime { -namespace tvm_codegen { - -Status SliceCommon(const tvm::Array& inputs, - const Node& node, - tvm::Array& outputs, - const std::vector& starts, - const std::vector& ends, - const std::vector& axes1, - const std::vector& steps1) { - ORT_RETURN_IF_NOT(nullptr != node.InputDefs()[0], "nullptr == node.InputDefs()[0]"); - - std::vector axes; - if (axes1.size() == 0) { - for (size_t i = 0; i < starts.size(); ++i) { - axes.push_back(gsl::narrow_cast(i)); - } - } else { - axes = axes1; - } - - std::vector steps; - if (steps1.size() == 0) { - steps.resize(starts.size(), 1); - } else { - steps = steps1; - } - - tvm::Tensor Y = Slice(inputs[0], starts, ends, axes, steps, node.Name() + "_Slice"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Slice OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Slice)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - // NOTE that in opset 10, Slice has changed starts/ends/axes from attribute to input - // which may lead to dynamic output shape. - int version = ctx_codegen.GetCodeGenHandle()->domain_version_lookup_func(node.Domain()); - ORT_RETURN_IF_NOT(version <= 9, "Dynamic Slice is not supported yet"); - - std::vector starts, ends, steps; - ORT_RETURN_IF_ERROR(info.GetAttrs("starts", starts)); - ORT_RETURN_IF_ERROR(info.GetAttrs("ends", ends)); - ORT_RETURN_IF_NOT(starts.size() == ends.size(), "starts.size() != ends.size()"); - - auto axes = info.GetAttrsOrDefault("axes"); - - return SliceCommon(inputs, node, outputs, starts, ends, axes, steps); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/split.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/split.cc deleted file mode 100644 index ec52d98b5bf96..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/split.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/split.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Split OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Split)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis)); - axis = HandleNegativeAxis(axis, gsl::narrow_cast(inputs[0]->shape.size())); - std::vector split_sizes; - - int64_t split_size_sum = 0; - if (info.GetAttrs("split", split_sizes).IsOK()) { - // optional - split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL); - ORT_RETURN_IF_NOT(std::all_of(split_sizes.cbegin(), split_sizes.cend(), [](int64_t value) { return value > 0; }), - "Invalid value in 'split' attribute. All values must be > 0"); - - // check split sizes - for (size_t i = 0; i < node.OutputDefs().size(); ++i) { - ORT_RETURN_IF_NOT(split_sizes[i] == ShapeValue(node.OutputDefs()[i], gsl::narrow(axis)), - "split_sizes[i] != ShapeValue(node.OutputDefs()[i], axis)"); - } - - } else { - for (size_t i = 0; i < node.OutputDefs().size(); ++i) { - split_sizes.push_back(ShapeValue(node.OutputDefs()[i], gsl::narrow(axis))); - split_size_sum += split_sizes[i]; - } - } - - // check total size - if (ShapeHasValue(node.InputDefs()[0], axis)) { - int64_t input_axis_dim = ShapeValue(node.InputDefs()[0], axis); - if (split_size_sum != input_axis_dim) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Cannot split using values in 'split' attribute. Axis=", axis, - " Dim being splitted=", input_axis_dim, - " Sum of sizes in 'split' (must equal size of selected axis) was ", split_size_sum); - } - } - - tvm::Array output_tensors = Split(inputs[0], ToTvmArray(split_sizes), axis, node.Name() + "_Split"); - for (size_t i = 0; i < node.OutputDefs().size(); ++i) { - outputs.push_back(output_tensors[i]); - } - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc deleted file mode 100644 index 43999ebd1f465..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/transpose.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Transpose OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Transpose)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - size_t input_0_shape_rank = inputs[0]->shape.size(); - std::vector permute; - bool is_ok = attrs.GetAttrs("perm", permute).IsOK(); - if (permute.size() != 0 && permute.size() != input_0_shape_rank) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose: Incorrect permute size"); - - std::vector default_permute; - const std::vector* perm; - // either we don't have perm attribute or the perm attribute is empty - bool use_default_perm = !is_ok || permute.size() == 0; - if (use_default_perm) { - default_permute.resize(input_0_shape_rank); - for (size_t i = 0; i < input_0_shape_rank; ++i) { - default_permute[i] = gsl::narrow(input_0_shape_rank - 1 - i); - } - perm = &default_permute; - } else { - perm = &permute; - } - - tvm::Tensor Y = Transpose(inputs[0], ToTvmArrayInt(*perm), node.Name() + "_Transpose"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/where.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/where.cc deleted file mode 100644 index 9d6df7c1c430d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/where.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/where.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Transpose OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Where)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - tvm::Tensor Y = Where(inputs[0], inputs[1], inputs[2], node.Name() + "_Where"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.cc deleted file mode 100644 index 7889e2add755e..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/passes/op_ir_creator/all_ops.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -TVMIRBuilder::TVMIRBuilder(const std::string& name) - : name_(name) {} - -const std::string& TVMIRBuilder::Name() const { - return name_; -} - -void TVMIRBuilder::InsertDispatcher(std::unique_ptr&& ptr) { - dispatchers_.push_back(std::move(ptr)); -} - -void TVMIRBuilder::ClearAllDispatchers() { - dispatchers_.clear(); -} - -void TVMIRBuilder::DumpAllOpIRCreators() const { - int count = 0; - for (auto& d : dispatchers_) { - std::cout << "************ TVM OpIRDispatcher " - << count << " : " - << d->Name() - << " ************" << std::endl; - - d->ForEach([](const std::string& key, OpIRCreator* builder) { - std::cout << "Key " << key - << ", Creator " << builder->Name() << std::endl; - }); - - ++count; - } -} - -// Evaluate finds ONE proper OpIRCreator and build the corresponding OpIR -// If a TVMIRBuilder has more than one OpIRCreator for an ORT Op, -// the first one will be used. -// Please adjust registration order and dispatcher in TVMIRBuilder -// to make sure the proper OpIRCreator is called. -Status TVMIRBuilder::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - OpIRCreator* candidate = nullptr; - for (auto& d : dispatchers_) { - candidate = d->Find(node); - if (nullptr != candidate) - break; - } - - if (nullptr == candidate) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not implemented: ", node.OpType()); - } - - ORT_RETURN_IF_ERROR(candidate->Evaluate(inputs, node, ctx_codegen, outputs)); - - return Status::OK(); -} - -// BEGIN: Generic IR creator classes -#define ADD_OP_ITEM(name) \ - op_ir_registry->Register(std::make_unique()); - -#define BINARY_OP(name) ADD_OP_ITEM(name) -#define BINARY_CMP_OP(name) ADD_OP_ITEM(name) -#define POOL_OP(name) ADD_OP_ITEM(name) -#define REDUCE_OP(name) ADD_OP_ITEM(name) -#define REDUCE_INDEXED_OP(name) ADD_OP_ITEM(name) -#define UNARY_OP(name) ADD_OP_ITEM(name) -#define VARIADIC_OP(name) ADD_OP_ITEM(name) - -void RegisterAllGenericOpIRCreators(OpIRRegistry* op_ir_registry) { - LIST_ALL_GENERIC_OPS(); -} - -#undef ADD_OP_ITEM -#undef BINARY_OP -#undef BINARY_CMP_OP -#undef POOL_OP -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP -#undef UNARY_OP -#undef VARIADIC_OP - -// BEGIN: Plugin Generic IR creator classes -#define ADD_OP_ITEM(name) \ - dispatcher->Register(#name, registry->Get(GENERIC_OP_IR_CREATOR_STRING(name))); - -#define BINARY_OP(name) ADD_OP_ITEM(name) -#define BINARY_CMP_OP(name) ADD_OP_ITEM(name) -#define POOL_OP(name) ADD_OP_ITEM(name) -#define REDUCE_OP(name) ADD_OP_ITEM(name) -#define REDUCE_INDEXED_OP(name) ADD_OP_ITEM(name) -#define UNARY_OP(name) ADD_OP_ITEM(name) -#define VARIADIC_OP(name) ADD_OP_ITEM(name) - -void RegisterGenericOrtOpTypeDispatcher(const std::shared_ptr& builder, - const OpIRRegistry* registry) { - auto dispatcher = std::make_unique("GenericOrtOpTypeOpIRCreators"); - LIST_ALL_GENERIC_OPS() - builder->InsertDispatcher(std::move(dispatcher)); -} - -#undef ADD_OP_ITEM -#undef BINARY_OP -#undef BINARY_CMP_OP -#undef POOL_OP -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP -#undef UNARY_OP -// END: Generic IR creators classes - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.h 
b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.h deleted file mode 100644 index c80056e619d6d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// TVMIRBuilder contains all applicable TVM OpIRCreators -// OpIRCreators are stored in multiple dispatchers -// that check different conditions of an ORT Node. - -// If an ORT Node satisfies more than one OpIRCreators, -// the first dispatched pass will be applied. - -class TVMIRBuilder { - public: - TVMIRBuilder(const std::string& name); - ~TVMIRBuilder() = default; - - // A debug dumps all existing in this TVMIRBuilders - void DumpAllOpIRCreators() const; - - // Evaluates an OpIRCreator that first satisfies condtions of all dispatchers - Status Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx, - tvm::Array& outputs); - - // Inserts a dispatcher and move its ownership to this TVMIRBuilder - void InsertDispatcher(std::unique_ptr&& ptr); - - // Clears all dispatchers in this TVMIRBuilder - void ClearAllDispatchers(); - - // Dumps the name of this TVMIRBuilder - const std::string& Name() const; - - private: - std::vector> dispatchers_; - std::string name_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVMIRBuilder); -}; - -// Utility function to register all builtin generic OpIRCreators into an OpIRRegistry. -// It creates instances of all generic OpIRCreators -// and registers them to op_ir_registry -void RegisterAllGenericOpIRCreators(OpIRRegistry* op_ir_registry); - -// Utility function to bind all builtin generic OpIRCreators to a TVMIRBuilder. -// It creates an instance of a Dispatcher that contains all generic OpIRCreators created above -// and uses OrtOpType to dispatch OpIRCreators. -// Then, it registers the created Dispatcher to a TVMIRBuilder, builder. -void RegisterGenericOrtOpTypeDispatcher(const std::shared_ptr& builder, - const OpIRRegistry* registry); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.cc deleted file mode 100644 index 992272753f5a4..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
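TVMIRBuilder::Evaluate, deleted above, walks its dispatchers in registration order and uses the first OpIRCreator whose Find() matches the node; one built-in dispatcher keys on the ORT op type and another on the node name. A stripped-down sketch of that first-match lookup chain (NodeSketch, Dispatcher, and Creator are hypothetical stand-ins built on standard containers, not the ORT classes):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct NodeSketch { std::string op_type, name; };
using Creator = std::function<void(const NodeSketch&)>;

// A dispatcher maps some key derived from the node to a creator.
class Dispatcher {
 public:
  explicit Dispatcher(std::function<std::string(const NodeSketch&)> key_fn)
      : key_fn_(std::move(key_fn)) {}
  void Register(const std::string& key, Creator c) { table_[key] = std::move(c); }
  const Creator* Find(const NodeSketch& n) const {
    auto it = table_.find(key_fn_(n));
    return it == table_.end() ? nullptr : &it->second;
  }
 private:
  std::function<std::string(const NodeSketch&)> key_fn_;
  std::unordered_map<std::string, Creator> table_;
};

// The builder tries each dispatcher in registration order; the first hit wins.
void Evaluate(const std::vector<std::unique_ptr<Dispatcher>>& dispatchers,
              const NodeSketch& node) {
  for (const auto& d : dispatchers)
    if (const Creator* c = d->Find(node)) { (*c)(node); return; }
  std::cout << "Not implemented: " << node.op_type << "\n";
}

int main() {
  std::vector<std::unique_ptr<Dispatcher>> dispatchers;
  auto by_name = std::make_unique<Dispatcher>([](const NodeSketch& n) { return n.name; });
  by_name->Register("special_node", [](const NodeSketch&) { std::cout << "custom creator\n"; });
  auto by_type = std::make_unique<Dispatcher>([](const NodeSketch& n) { return n.op_type; });
  by_type->Register("MatMul", [](const NodeSketch&) { std::cout << "generic MatMul creator\n"; });
  dispatchers.push_back(std::move(by_name));  // node-name overrides come first
  dispatchers.push_back(std::move(by_type));
  Evaluate(dispatchers, {"MatMul", "special_node"});  // custom creator
  Evaluate(dispatchers, {"MatMul", "other_node"});    // generic MatMul creator
  Evaluate(dispatchers, {"Conv", "other_node"});      // Not implemented: Conv
}

Registration order is the whole dispatch policy, which is why the deleted comments warn that the first matching OpIRCreator is the one used.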
- -#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" - -#include "core/codegen/common/common.h" -#include "core/codegen/common/dispatcher.h" -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace codegen { -// Explicit instantiation for OpIRCreator -template class CreatorBase&, - const Node&, - tvm_codegen::CodeGenContext&, - tvm::Array&, - Status>; - -// Explicit instantiation for OpIRCreators' dispatcher -template class DispatcherBase; - -} // namespace codegen - -namespace tvm_codegen { - -// One dispatcher is based on ORT OpType -OpIRCreator* OP_IR_DISPATCHER_CLASS(OpType)::Find(const Node& node) { - return DispatcherBase::Get(node.OpType()); -} - -// Another dispatcher is based ORT NodeArg name (GetKey) -OpIRCreator* OP_IR_DISPATCHER_CLASS(NodeName)::Find(const Node& node) { - return DispatcherBase::Get(GetKey(&node)); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h deleted file mode 100644 index e29c4a9f20767..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/creator.h" -#include "core/codegen/common/dispatcher.h" -#include "core/codegen/common/registry.h" -#include "core/graph/graph.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -class CodeGenContext; - -// OpIRCreator lowers an Ort Node to its corresponding TVM IRs -using OpIRCreator = codegen::CreatorBase< - const tvm::Array&, - const Node&, - CodeGenContext&, - tvm::Array&, - Status>; - -// OpIRDispatcher is the base dispatcher for TVM IR Builder -// It checks whether an Ort Node satisfying a criteria (in Find) -// and dispatches a corresponding OpIRCreator. 
-class OpIRDispatcher : public codegen::DispatcherBase { - public: - OpIRDispatcher(const std::string& name) - : DispatcherBase(name) {} - - virtual ~OpIRDispatcher() = default; - - virtual OpIRCreator* Find(const Node&) = 0; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpIRDispatcher); -}; - -// Macro returns an OpIRCreators' dispatcher's name -#define OP_IR_DISPATCHER_CLASS(OP) \ - TVM##OP##IRCreator - -// Macro declares an OpIRCreators' dispatcher -#define DECLARE_OP_IR_DISPATCHER_CLASS(OP) \ - class OP_IR_DISPATCHER_CLASS(OP) : public OpIRDispatcher { \ - public: \ - TVM##OP##IRCreator(const std::string& name) \ - : OpIRDispatcher(name) {} \ - ~TVM##OP##IRCreator() = default; \ - OpIRCreator* Find(const Node&) override; \ - \ - private: \ - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OP_IR_DISPATCHER_CLASS(OP)); \ - }; - -// Declare two common dispatchers for TVM Op IR builders -// One dispatcher is based on Ort OpType -DECLARE_OP_IR_DISPATCHER_CLASS(OpType) -// Another dispatcher is based Ort NodeArg name -DECLARE_OP_IR_DISPATCHER_CLASS(NodeName) - -// OpIRCreator Registry is a registry holds all OpIRCreators -using OpIRRegistry = codegen::RegistryBase; - -// Macro declares an OpIRCreator -#define DECLARE_OP_IR_CREATOR_CLASS(OP, PREFIX) \ - DECLARE_CREATOR_CLASS(OP, PREFIX##IRCreator, \ - const tvm::Array&, \ - const Node&, \ - tvm_codegen::CodeGenContext&, \ - tvm::Array&, \ - Status) - -// Macro returns an OpIRCreator's name with prefix -#define OP_IR_CREATOR_CLASS_EX(OP, PREFIX, ARCH) \ - CREATOR_CLASS(OP, PREFIX##ARCH##IRCreator) - -// Macro declares an OpIRCreator with prefix and arch -#define DECLARE_OP_IR_CREATOR_CLASS_EX(OP, PREFIX, ARCH) \ - DECLARE_OP_IR_CREATOR_CLASS(OP, PREFIX##ARCH) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/all_schedules.h b/onnxruntime/core/codegen/passes/scheduler/all_schedules.h deleted file mode 100644 index fe4be90f9fc84..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/all_schedules.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/passes/scheduler/tvm_scheduler.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// AlwaysRoot is for debug purpose -DECLARE_TVM_SCHEDULER_CLASS(AlwaysRoot, GenericTVMRule) -// Create schedule for TVM Rule -DECLARE_TVM_SCHEDULER_CLASS(Extern, GenericTVMRule) -DECLARE_TVM_SCHEDULER_CLASS(Reduce, GenericTVMRule) - -// Crete scheduler for ORT OpType, Softmax -DECLARE_TVM_SCHEDULER_CLASS(Softmax, GenericOrtOpType) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/ort_type_schedule.cc b/onnxruntime/core/codegen/passes/scheduler/ort_type_schedule.cc deleted file mode 100644 index 59f492d164b14..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/ort_type_schedule.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/scheduler/all_schedules.h" - -#include "core/codegen/passes/scheduler/schedule_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -bool TVM_SCHEDULER_CLASS(Softmax, GenericOrtOpType)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - // compute root the exp since it is reused more than once - auto& tensor_exp = tensor->op->InputTensors()[0]; - return InsertRootSchedule(tensor_exp, ctx_sched); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.cc b/onnxruntime/core/codegen/passes/scheduler/schedule_utils.cc deleted file mode 100644 index 76c2ad509c401..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.cc +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/common/utils.h" -#include "core/codegen/passes/scheduler/schedule_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Check the schedule of tensor -// If it has no compute_root, Insert compute_root to tensor, and record it to ctx.scheduled_tensors -bool InsertRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second == ScheduleType::ScheduleClosure || - it->second == ScheduleType::ScheduleRoot) { - return false; - } - it->second = ScheduleType::ScheduleRoot; - } else { - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleRoot)); - } - ctx.schedule[tensor->op].compute_root(); - return true; -} - -// Check the schedule of tensor -// If it is not labeled as closure, lable it. -bool InsertClosure(const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second == ScheduleType::ScheduleClosure) - return false; - it->second = ScheduleType::ScheduleClosure; - } else { - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleClosure)); - } - return true; -} - -// Combination of InsertRootSchedule and InsertClosure -bool InsertRootScheduleAndClosure( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second == ScheduleType::ScheduleClosure) { - return false; - } - it->second = ScheduleType::ScheduleClosure; - } else { - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleClosure)); - } - ctx.schedule[tensor->op].compute_root(); - return true; -} - -// Check precondition for vectorize schedule -bool ShouldTryVectorization( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second > ScheduleType::ScheduleInline) { - return false; - } - } - return true; -} - -// Check the schedule of tensor -// If it is not scheduled, try to vectorize it. -// Note TryVectorization has to use with compute_root. 
-// Therefore, there is a safety check of tensor's schedule -bool TryVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx) { - if (!ShouldTryVectorization(tensor, ctx)) - return false; - - auto shape = tensor->shape; - auto rank = shape.size(); - if (rank < 1) { - return false; - } - const int64_t* tail_dim = as_const_int(shape[rank - 1]); - - if (nullptr != tail_dim) { - auto extern_op = tensor->op.as(); - if (nullptr != extern_op) { - return false; - } - - auto compute_op = tensor->op.as(); - - if (nullptr != compute_op) { - auto axis = compute_op->axis; - tvm::IterVar x = axis[rank - 1]; - if ((*tail_dim) > natural_vector_size) { - if ((*tail_dim) % natural_vector_size != 0) { - natural_vector_size = GCD(natural_vector_size, (*tail_dim)); - } - - if (natural_vector_size > 1) { - tvm::IterVar xi, xo; - ctx.schedule[tensor->op].split(x, static_cast(natural_vector_size), &xo, &xi); - ctx.schedule[tensor->op].vectorize(xi); - return true; - } - } else if (*tail_dim > 0) { - // don't vectorize if dim is 0 - ctx.schedule[tensor->op].vectorize(x); - return true; - } - } - } - return false; -} - -// Check the schedule of tensor -// If it is not scheduled, try to add compute_inline on it. -// Note TryInlineSchedule cannot be used with compute_root. -// Therefore, there is a safety check of tensor's schedule. -bool TryInlineSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if ((int)it->second < (int)ScheduleType::ScheduleInline) { - ctx.schedule[tensor->op].compute_inline(); - it->second = ScheduleType::ScheduleInline; - return true; - } else { - return false; - } - } - ctx.schedule[tensor->op].compute_inline(); - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleInline)); - return true; -} - -// Check the schedule of tensor's inputs, and call InsertRootSchedule for each of them -bool InputRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - bool status = false; - for (auto& t : tensor->op->InputTensors()) { - if (t->op->InputTensors().size() > 0) { - bool status_root = InsertRootSchedule(t, ctx); - status = status || status_root; - } - } - return status; -} - -// Check the schedule of tensor's inputs, -// and call InsertRootSchedule and TryVectorization for each of them -bool InputRootScheduleWithVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx) { - bool status = false; - for (auto& t : tensor->op->InputTensors()) { - if (t->op->InputTensors().size() > 0) { - bool status_vec = TryVectorization(t, natural_vector_size, ctx); - bool status_root = InsertRootSchedule(t, ctx); - status = status || status_root || status_vec; - } - } - return status; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.h b/onnxruntime/core/codegen/passes/scheduler/schedule_utils.h deleted file mode 100644 index 4a0781f94d385..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
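The TryVectorization helper removed above picks the split factor for a tensor's innermost axis: it uses the target's natural vector width when the tail dimension is a multiple of it, otherwise it falls back to the greatest common divisor of the two, and it skips vectorization for zero-sized or unprofitable tails. A minimal standalone sketch of that width-selection rule, in plain C++ with std::gcd standing in for the deleted GCD helper (no TVM types; parameter names mirror the deleted code):

```cpp
#include <cstdint>
#include <numeric>  // std::gcd

// Returns the factor the innermost axis would be split/vectorized by,
// or 0 when vectorization is not worthwhile.
int64_t ChooseVectorWidth(int64_t tail_dim, int64_t natural_vector_size) {
  if (tail_dim <= 0) return 0;                            // a zero-sized tail is never vectorized
  if (tail_dim <= natural_vector_size) return tail_dim;   // short tail: vectorize the whole axis
  if (tail_dim % natural_vector_size != 0) {
    // largest width that still divides the tail evenly
    natural_vector_size = std::gcd(natural_vector_size, tail_dim);
  }
  return natural_vector_size > 1 ? natural_vector_size : 0;
}
```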
- -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// Check the schedule of tensor -// If it has no compute_root, Insert compute_root to tensor, -// and record it to ctx.scheduled_tensors -bool InsertRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor -// If it is not labeled as closure, lable it. -bool InsertClosure( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Combination of InsertRootSchedule and InsertClosure -bool InsertRootScheduleAndClosure( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check precondition for vectorize schedule -bool ShouldTryVectorization( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor -// If it is not scheduled, try to vectorize it. -// Note TryVectorization has to use with compute_root. -// Therefore, there is a safety check of tensor's schedule -bool TryVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx); - -// Check the schedule of tensor -// If it is not scheduled, try to add compute_inline on it. -// Note TryInlineSchedule cannot be used with compute_root. -// Therefore, there is a safety check of tensor's schedule. -bool TryInlineSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor's inputs, -// and call InsertRootSchedule for each of them -bool InputRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor's inputs, -// and call InsertRootSchedule and TryVectorization for each of them -bool InputRootScheduleWithVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_rule_schedule.cc b/onnxruntime/core/codegen/passes/scheduler/tvm_rule_schedule.cc deleted file mode 100644 index 33162deddc983..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_rule_schedule.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/scheduler/all_schedules.h" - -#include "core/codegen/passes/scheduler/schedule_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// This is for debug -bool TVM_SCHEDULER_CLASS(AlwaysRoot, GenericTVMRule)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - return InsertRootSchedule(tensor, ctx_sched); -} - -// For External tvm::Tensor -bool TVM_SCHEDULER_CLASS(Extern, GenericTVMRule)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - bool status = InsertRootScheduleAndClosure(tensor, ctx_sched); - bool status_input = InputRootSchedule(tensor, ctx_sched); - return status || status_input; -} - -// For Reduce Compute tvm::Tensor -bool TVM_SCHEDULER_CLASS(Reduce, GenericTVMRule)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - return InsertRootScheduleAndClosure(tensor, ctx_sched); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.cc b/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.cc deleted file mode 100644 index 2c8250198fa5f..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/scheduler/tvm_schedule_builder.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/common/settings.h" -#include "core/common/common.h" -#include "core/common/logging/logging.h" - -namespace onnxruntime { -namespace tvm_codegen { - -TVMScheduleBuilder::TVMScheduleBuilder(const std::string& name) - : name_(name) { -} - -const std::string& TVMScheduleBuilder::Name() const { - return name_; -} - -void TVMScheduleBuilder::InsertDispatcher(std::unique_ptr&& ptr) { - dispatchers_.push_back(std::move(ptr)); -} - -void TVMScheduleBuilder::ClearDispatcher() { - dispatchers_.clear(); -} - -void TVMScheduleBuilder::DumpAllSchedulers() const { - std::ostringstream stream; - int count = 0; - stream << "[CODEGEN_DUMP_SCHEDULE]" << std::endl; - for (auto& d : dispatchers_) { - stream << "************ TVM Scheduler Dispatcher " - << count << " : " - << d->Name() - << " ************" << std::endl; - - d->ForEach([&stream](const std::string& key, Scheduler* op) { - stream << "Key " << key - << ", Creator " << op->Name() << std::endl; - }); - - ++count; - } - - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << stream.str(); -} - -Status TVMScheduleBuilder::Evaluate( - const tvm::Tensor& tensor, - const Node* node, - CodeGenContext& ctx_codegen, - ScheduleContext& sched) { - Scheduler* candidate = nullptr; - - for (auto& d : dispatchers_) { - candidate = d->Find(tensor, node, ctx_codegen); - if (nullptr != candidate) - break; - } - - bool enable_dump_schedule = codegen::CodeGenSettings::Instance().HasOption(codegen::CodeGenSettings::kCodeGenDumpSchedule); - - if (nullptr == candidate) { - if (nullptr != node) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not implemented: ", node->OpType()); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not implemented an internal tvm::Tensor: ", tensor->op->name); - } - - bool status = candidate->Evaluate(tensor, node, ctx_codegen, sched); - - if (enable_dump_schedule) { - std::ostringstream stream; - if (nullptr != node) { - stream << std::endl; - stream << "[CODEGEN_DUMP_SCHEDULE] " 
- << "Schedule Node: " << node->Name() << std::endl; - } else { - stream << std::endl; - } - - if (status) { - stream << "[CODEGEN_DUMP_SCHEDULE] " - << "Schedule tvm::Tesnor " - << tensor->op->name - << " with " - << candidate->Name() << std::endl; - } else { - stream << "[CODEGEN_DUMP_SCHEDULE] " - << "Schedule tvm::Tesnor " - << tensor->op->name - << " is suppressed " << std::endl; - } - - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << stream.str(); - } - - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.h b/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.h deleted file mode 100644 index 9f0a1b3ef45c2..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/passes/scheduler/tvm_scheduler.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// TVMScheduleBuilder contains all applicable TVM scheduler passes. -// Scheduler passes are stored in multiple dispatchers -// that check different conditions of a tvm::Tensor. - -// If a tvm::Tensor satisfies more than one TVM scheduler passes, -// the first dispatched pass will be applied. - -class TVMScheduleBuilder { - public: - // TODO: add more parameter in consructor to support different target - TVMScheduleBuilder(const std::string& name); - ~TVMScheduleBuilder() = default; - - void DumpAllSchedulers() const; - - Status Evaluate( - const tvm::Tensor& tensor, - const Node* node, - CodeGenContext& ctx, - ScheduleContext& sched); - - void InsertDispatcher(std::unique_ptr&& ptr); - void ClearDispatcher(); - - const std::string& Name() const; - - private: - std::vector> dispatchers_; - std::string name_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVMScheduleBuilder); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.cc b/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.cc deleted file mode 100644 index 071200a234e33..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.cc +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/codegen/passes/scheduler/tvm_scheduler.h" - -#include "core/codegen/common/common.h" -#include "core/codegen/common/dispatcher.h" -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace codegen { -// explicit instantiation -template class CreatorBase; - -template class DispatcherBase; - -} // namespace codegen - -namespace tvm_codegen { - -static const std::string TMVOpRuleKey_Extern("TVMOpRule_Extern"); -static const std::string TMVOpRuleKey_ComputeReduce("TVMOpRule_ComputeReduce"); -static const std::string TMVOpRuleKey_ComputeRegular("TVMOpRule_ComputeRegular"); -static const std::string TMVOpRuleKey_AlwaysRoot("TMVOpRuleKey_AlwaysRoot"); -static const std::string TMVOpRuleKey_NoRule("TVMOpRule_NoRule"); - -const std::string& GetTVMOpRule(TVMOpRuleType rule) { - if (rule == TVMOpRuleType::Extern) { - return TMVOpRuleKey_Extern; - } else if (rule == TVMOpRuleType::ComputeReduce) { - return TMVOpRuleKey_ComputeReduce; - } else if (rule == TVMOpRuleType::AlwaysRoot) { - return TMVOpRuleKey_AlwaysRoot; - } - return TMVOpRuleKey_NoRule; -} - -const std::string& GetTVMOpRule(const tvm::Tensor& tensor) { - auto extern_op = tensor->op.as(); - - if (nullptr != extern_op) { - return TMVOpRuleKey_Extern; - } - - auto compute_op = tensor->op.as(); - if (nullptr != compute_op) { - if (compute_op->reduce_axis.size() > 0) { - return TMVOpRuleKey_ComputeReduce; - } - } - - return TMVOpRuleKey_NoRule; -} - -Scheduler* SCHEDULE_DISPATCHER_CLASS(OrtOpType):: - Find(const tvm::Tensor&, const Node* node, tvm_codegen::CodeGenContext&) { - if (nullptr == node) - return nullptr; - return DispatcherBase::Get(node->OpType()); -} - -Scheduler* SCHEDULE_DISPATCHER_CLASS(TVMOpRule):: - Find(const tvm::Tensor& tensor, const Node*, tvm_codegen::CodeGenContext&) { - return DispatcherBase::Get(GetTVMOpRule(tensor)); -} - -Scheduler* SCHEDULE_DISPATCHER_CLASS(OrtOpName):: - Find(const tvm::Tensor&, const Node* node, tvm_codegen::CodeGenContext&) { - if (nullptr == node) - return nullptr; - return DispatcherBase::Get(GetKey(node)); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h b/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h deleted file mode 100644 index d022497c77f7e..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/codegen/common/creator.h" -#include "core/codegen/common/registry.h" -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/graph/graph.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// These are current generic TVMOpRule we used. 
-enum class TVMOpRuleType : int { - Extern = 0, - ComputeReduce = 1, - ComputeRegular = 2, - AlwaysRoot = 3, // for debug - NoRule, -}; - -const std::string& GetTVMOpRule(const tvm::Tensor& tensor); -const std::string& GetTVMOpRule(TVMOpRuleType rule); - -// These are current generic ScheduleType in tvm_codegen -enum class ScheduleType : int { - ScheduleNone = 0, - ScheduleInline = 1, - ScheduleAt = 2, - ScheduleRoot = 3, - ScheduleClosure = 4, -}; - -// Data struct to bundle tvm::Schedule and scheduled tensor -struct ScheduleContext { - ScheduleContext(const tvm::Array& ops) { - schedule = tvm::create_schedule(ops); - } - tvm::Schedule schedule; - std::map scheduled_tensors; -}; - -// Scheduler inserts a tvm::Schedule content to a tvm::Tensor -using Scheduler = codegen::CreatorBase< - const tvm::Tensor&, - const Node*, - tvm_codegen::CodeGenContext&, - ScheduleContext&, - bool>; - -// TVMScheduleDispatcher is the base dispatcher for TVM Schedule Builder -// It checks whether a pair of {tvm::Tensor, Ort Node} satisfying a criteria (in Find) -// and dispatches a corresponding Scheduler. -class TVMScheduleDispatcher : public codegen::DispatcherBase { - public: - TVMScheduleDispatcher(const std::string& name) - : DispatcherBase(name) {} - - virtual ~TVMScheduleDispatcher() = default; - - virtual Scheduler* Find(const tvm::Tensor&, - const Node*, - tvm_codegen::CodeGenContext&) = 0; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVMScheduleDispatcher); -}; - -// Macro returns an Schedulers' dispatcher's name -#define SCHEDULE_DISPATCHER_CLASS(TYPE) \ - TVM##TYPE##Schedulers - -// Macro declares an Schedulers' dispatcher -#define DECLARE_SCHEDULE_DISPATCHER_CLASS(TYPE) \ - class SCHEDULE_DISPATCHER_CLASS(TYPE) : public tvm_codegen::TVMScheduleDispatcher { \ - public: \ - TVM##TYPE##Schedulers(const std::string& name) \ - : TVMScheduleDispatcher(name) {} \ - ~TVM##TYPE##Schedulers() = default; \ - tvm_codegen::Scheduler* Find(const tvm::Tensor&, \ - const Node*, \ - tvm_codegen::CodeGenContext&) override; \ - \ - private: \ - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVM##TYPE##Schedulers); \ - }; - -// Common dispatchers are listed here -// For a special pattern, it can be created later. 
-// One dispatcher is based on Ort OpType -DECLARE_SCHEDULE_DISPATCHER_CLASS(OrtOpType) -// One dispatcher is based on TVMOpRule -DECLARE_SCHEDULE_DISPATCHER_CLASS(TVMOpRule) -// One dispatcher is based Ort NodeArg name -DECLARE_SCHEDULE_DISPATCHER_CLASS(OrtOpName) - -// Scheduler Registry is a registry holds all Schedulers -using TVMScheduleRegistry = codegen::RegistryBase; - -// Macro declares TVM scheduler class -#define DECLARE_TVM_SCHEDULER_CLASS(OP, PRETFIX) \ - DECLARE_CREATOR_CLASS(OP, PRETFIX##Scheduler, \ - const tvm::Tensor&, \ - const Node*, \ - tvm_codegen::CodeGenContext&, \ - tvm_codegen::ScheduleContext&, \ - bool) - -// Macro returns TVM scheduler's name with prefix -#define TVM_SCHEDULER_CLASS(OP, PREFIX) \ - CREATOR_CLASS(OP, PREFIX##Scheduler) - -// Macro returns TVM scheduler's name as string -#define TVM_SCHEDULER_STRING(OP, PREFIX) \ - STRINGIZE(TVM_SCHEDULER_CLASS(OP, PREFIX)) - -// Macro returns TVM scheduler's name with prefix and arch -#define TVM_SCHEDULER_CLASS_EX(OP, PREFIX, ARCH) \ - CREATOR_CLASS(OP, PREFIX##ARCH##Scheduler) - -// Macro declares TVM scheduler class with prefix and arch -#define DECLARE_TVM_SCHEDULER_CLASS_EX(OP, PREFIX, ARCH) \ - DECLARE_TVM_SCHEDULER_CLASS(OP, PREFIX##ARCH) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/codegen_context.cc b/onnxruntime/core/codegen/passes/utils/codegen_context.cc deleted file mode 100644 index 2f1a59b4a92eb..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/codegen_context.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/utils/codegen_context.h" - -#include "core/codegen/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -CodeGenContext::CodeGenContext( - const codegen::CodeGenHandle* handle) - : handle_(handle), unname_symbol_counter_(0) {} - -tvm::Var CodeGenContext::GetOrCreateDynamicDim(const std::string& name) { - if (dynamic_dims_.count(name) == 0) - dynamic_dims_.emplace(name, tvm::Var(name)); - - return dynamic_dims_.at(name); -} - -std::string CodeGenContext::CreateUnnamedSymbol() { - return "unnamed_" + std::to_string(unname_symbol_counter_++); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/codegen_context.h b/onnxruntime/core/codegen/passes/utils/codegen_context.h deleted file mode 100644 index 641552bd3b2e8..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/codegen_context.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
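The CodeGenContext removed here hands out one symbolic variable per dynamic-dimension name by caching it in a map, so repeated occurrences of the same dim_param resolve to the same variable, while dimensions without a name get a fresh "unnamed_<counter>" symbol. A small get-or-create sketch of that bookkeeping; SymbolicDim and DimContext are stand-ins for the tvm::Var-based originals:

```cpp
#include <string>
#include <unordered_map>

struct SymbolicDim { std::string name; };  // stand-in for tvm::Var

class DimContext {
 public:
  // The same name always yields the same symbolic dim; emplace is a no-op on repeat lookups.
  const SymbolicDim& GetOrCreateDynamicDim(const std::string& name) {
    return dims_.emplace(name, SymbolicDim{name}).first->second;
  }
  // Placeholder name for a dimension that has neither a concrete value nor a dim_param.
  std::string CreateUnnamedSymbol() { return "unnamed_" + std::to_string(counter_++); }

 private:
  std::unordered_map<std::string, SymbolicDim> dims_;
  int counter_ = 0;
};
```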
- -#pragma once -#include "core/codegen/common/handle.h" -#include "core/codegen/common/common.h" -#include "core/common/common.h" -#include "core/framework/data_types.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// CodeGenContext is a data structure involving across passes -// Compiler developers can use it to store meta data -// to support fine-grained control of code generation -class CodeGenContext { - public: - CodeGenContext(const codegen::CodeGenHandle* handle); - - virtual ~CodeGenContext() = default; - - // returns tvm::Var for the dynamic dim - tvm::Var GetOrCreateDynamicDim(const std::string& name); - - const codegen::CodeGenHandle* GetCodeGenHandle() const { - return handle_; - } - - std::string CreateUnnamedSymbol(); - - protected: - std::unordered_map dynamic_dims_; - - const codegen::CodeGenHandle* handle_; - - int unname_symbol_counter_; -}; - -// Add Promote for CodeGenContext -DYNAMIC_PROMOTE(CodeGenContext) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc deleted file mode 100644 index 55892974aa33f..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/utils/ort_tvm_utils.h" - -#include "core/codegen/common/profile.h" -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/framework/tensorprotoutils.h" -#include "core/providers/common.h" -#include - -#include - -namespace onnxruntime { -namespace tvm_codegen { - -#define RETURN_DLDATATYPE_IF_MATCH(type_enum, type, type_code) \ - case type_enum: \ - return {type_code, sizeof(type) * 8, 1}; \ - break; - -// DLDataType: {DLDataTypeCode, bits, lanes} -DLDataType ToTvmDLDataType(MLDataType ml_type) { - if (ml_type->IsTensorType()) { - ml_type = ml_type->AsTensorType()->GetElementType(); - } - auto prim_type = ml_type->AsPrimitiveDataType(); - if (prim_type == nullptr) { - ORT_NOT_IMPLEMENTED("converting MLDataType ", ml_type, " to tvm DLDataType is not implemented"); - } - - switch (prim_type->GetDataType()) { - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT8, int8_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT8, uint8_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT16, int16_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT16, uint16_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT32, int32_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT32, uint32_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT64, int64_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT64, uint64_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_BOOL, bool, kDLUInt); - - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, float, kDLFloat); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, double, kDLFloat); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, MLFloat16, kDLFloat); - default: - ORT_NOT_IMPLEMENTED("converting MLDataType ", ml_type, " to tvm DLDataType is not implemented"); - } -} - -tvm::Type ToTvmType(ONNX_NAMESPACE::TensorProto_DataType 
proto_type) { - switch (proto_type) { - // Note that bool is uint1 in tvm, but uint8 in ONNX, so it always require special handling - // case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - // return tvm::UInt(1); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT16: - return tvm::Int(16); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return tvm::Int(32); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return tvm::Int(64); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - return tvm::UInt(8); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT16: - return tvm::UInt(16); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - return tvm::UInt(32); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - return tvm::UInt(64); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - return tvm::Float(32); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - return tvm::Float(64); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - return tvm::Int(8); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - return tvm::Float(16); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_STRING: - ORT_THROW("Casting to and from strings is not supported yet."); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: - ORT_THROW("Cast op must have 'to' argument of type DataType"); /*break;*/ - default: - ORT_THROW("Unexpected 'to' argument value: ", proto_type); - } -} - -tvm::Array ShapeToTvmArray(const NodeArg* def, CodeGenContext& ctx) { - ORT_ENFORCE(nullptr != def); - const ONNX_NAMESPACE::TensorShapeProto* shape_proto = def->Shape(); - ORT_ENFORCE(nullptr != shape_proto); - - tvm::Array arr; - for (int i = 0; i < shape_proto->dim_size(); ++i) { - arr.push_back(ShapeDimToTvmDim(shape_proto->dim(i), ctx)); - } - return arr; -} - -tvm::Expr ShapeDimToTvmDim(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim, CodeGenContext& ctx) { - if (utils::HasDimParam(dim)) { - return ctx.GetOrCreateDynamicDim(dim.dim_param()); - } else if (utils::HasDimValue(dim)) { - return tvm::Expr(gsl::narrow_cast(dim.dim_value())); - } - return ctx.GetOrCreateDynamicDim(ctx.CreateUnnamedSymbol()); -} - -#ifdef CODEGEN_ENABLE_PROFILER -struct event_in_bracket_and_id { - bool in_bracket; - size_t id; -}; -std::unordered_map g_codegen_profiler_event_ids; -std::vector> g_codegen_profiler_events(1024); - -TVM_REGISTER_GLOBAL("tvm.contrib.onnxruntime.profile_event") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* ret) { - DLTensor* X = args[0]; - DLTensor* Y = args[1]; - size_t event_id = args[2]; - bool is_begin = args[3]; - if (!is_begin) { - DCHECK(event_id < g_codegen_profiler_event_ids.size()); - profiling::Profiler::Instance().EndTimeAndRecordEvent( - profiling::EventCategory::NODE_EVENT, - g_codegen_profiler_events[event_id].first, - g_codegen_profiler_events[event_id].second); - } - - { - CODEGEN_PROFILER_EVENT("profile_stub"); - int64_t elem_count = 1; - for (int i = 0; i < X->ndim; ++i) { - elem_count *= X->shape[i]; - } - // there's overhead in this copy, so put begin after copy and end before copy - memcpy(static_cast(Y->data) + Y->byte_offset, - static_cast(X->data) + X->byte_offset, - elem_count * X->dtype.bits / 8); - } - - if (is_begin) { - DCHECK(g_codegen_profiler_events.size() > event_id); - DCHECK(!g_codegen_profiler_events[event_id].first.empty()); - DCHECK(g_codegen_profiler_event_ids[g_codegen_profiler_events[event_id].first].id == event_id); - 
g_codegen_profiler_events[event_id].second = - profiling::Profiler::Instance().StartTime(); - } - }); - -tvm::Tensor ProfileBegin(tvm::Tensor X, const std::string& event_name) { - size_t event_id; - if (g_codegen_profiler_event_ids.count(event_name) == 0) { - event_id = g_codegen_profiler_event_ids.size(); - ORT_ENFORCE(event_id < g_codegen_profiler_events.size()); - } else { - ORT_ENFORCE(!g_codegen_profiler_event_ids[event_name].in_bracket); - event_id = g_codegen_profiler_event_ids[event_name].id; - } - g_codegen_profiler_event_ids[event_name] = {true, event_id}; - g_codegen_profiler_events[event_id].first = event_name; - return topi::detail::make_extern( - {X->shape}, {X->dtype}, {X}, - [&](tvm::Array ins, tvm::Array outs) { - return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.profile_event"), - topi::detail::pack_buffer(ins[0]), - topi::detail::pack_buffer(outs[0]), - gsl::narrow(event_id), - true}); - }, - event_name + "_begin", "", {})[0]; -} - -tvm::Tensor ProfileEnd(tvm::Tensor X, const std::string& event_name) { - ORT_ENFORCE(g_codegen_profiler_event_ids.at(event_name).in_bracket); - g_codegen_profiler_event_ids.at(event_name).in_bracket = false; - size_t event_id = g_codegen_profiler_event_ids.at(event_name).id; - ORT_ENFORCE(event_id < g_codegen_profiler_events.size()); - ORT_ENFORCE(g_codegen_profiler_events[event_id].first == event_name); - return topi::detail::make_extern( - {X->shape}, {X->dtype}, {X}, - [&](tvm::Array ins, tvm::Array outs) { - return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.profile_event"), - topi::detail::pack_buffer(ins[0]), - topi::detail::pack_buffer(outs[0]), - gsl::narrow(event_id), - false}); - }, - event_name + "_end", "", {})[0]; -} -#endif - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.h b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.h deleted file mode 100644 index f13e91a2d5cea..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/common.h" -#include "core/framework/data_types.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -class CodeGenContext; - -// Helper function that converts a onnxruntime MLDataType to TVM DLDataType -DLDataType ToTvmDLDataType(MLDataType ml_type); - -tvm::Type ToTvmType(ONNX_NAMESPACE::TensorProto_DataType proto_type); - -tvm::Array ShapeToTvmArray(const NodeArg* def, CodeGenContext& ctx); - -tvm::Expr ShapeDimToTvmDim(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim, CodeGenContext& ctx); - -#ifdef CODEGEN_ENABLE_PROFILER -// Helper functions to inspect into lowered function -tvm::Tensor ProfileBegin(tvm::Tensor X, const std::string& event_name); - -tvm::Tensor ProfileEnd(tvm::Tensor X, const std::string& event_name); -#endif - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.cc b/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.cc deleted file mode 100644 index c65132f6d4bca..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
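The profiling hooks deleted above pair ProfileBegin/ProfileEnd by event name: the first use of a name assigns it a stable id, and an in_bracket flag guards against opening the same event twice or closing one that was never opened. A simplified sketch of just that bookkeeping, without the TVM extern-op plumbing; EventRegistry is an illustrative name and assert replaces ORT_ENFORCE:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

struct EventState { bool in_bracket = false; size_t id = 0; };

class EventRegistry {
 public:
  size_t Begin(const std::string& name) {
    auto it = events_.find(name);
    const size_t id = (it == events_.end()) ? events_.size() : it->second.id;  // first use gets the next id
    assert(it == events_.end() || !it->second.in_bracket);  // no nested begin for the same name
    auto& ev = events_[name];
    ev.in_bracket = true;
    ev.id = id;
    return id;
  }

  size_t End(const std::string& name) {
    auto& ev = events_.at(name);  // must have been opened before
    assert(ev.in_bracket);
    ev.in_bracket = false;
    return ev.id;
  }

 private:
  std::unordered_map<std::string, EventState> events_;
};
```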
- -#include "core/codegen/passes/weight_layout/tiling_2d.h" - -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace tvm_codegen { - -constexpr auto local_name_prefix = "tiling_2d_"; -constexpr int num_bits = 8; - -const std::string WeightLayoutTiling2D::GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width) { - return WeightLayout::GetKey( - local_name_prefix + std::to_string(vector_width), - proto_type, 2, 0.0f); -} - -WeightLayoutTiling2D::WeightLayoutTiling2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width) - : WeightLayout( - local_name_prefix + std::to_string(vector_width), - proto_type, 2, 0.0f), - vector_width_(vector_width) {} - -CoordTransFunc WeightLayoutTiling2D::ToActual(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& nominal_coord) { - ORT_ENFORCE(nominal_coord.size() == 2); - const auto& y = nominal_coord[0]; - const auto& x = nominal_coord[1]; - return tvm::Array{ - x, - y}; - }; -} - -CoordTransFunc WeightLayoutTiling2D::ToNominal(const tvm::Tensor& X) const { - return [&](const tvm::Array& actual_coord) { - ORT_ENFORCE(actual_coord.size() == 2); - ORT_ENFORCE(X->dtype == HalideIR::type_of() || - X->dtype == HalideIR::type_of()); - - int tile_row = (sizeof(int32_t) * num_bits) / X->dtype.bits(); - int tile_col = ((vector_width_ * num_bits) / X->dtype.bits()) / tile_row; - - const auto& x = actual_coord[0]; - const auto& y = actual_coord[1]; - - const int block_dimy = tile_row; - const int block_dimx = tile_col; - - const auto& y0 = y % block_dimy; - const auto& y1 = (y / block_dimy) % block_dimx; - const auto& y2 = y / block_dimy / block_dimx; - - const auto& x0 = x % block_dimx; - const auto& x1 = x / block_dimx; - - return tvm::Array{ - y0 + y2 * block_dimx * block_dimy + x0 * block_dimy, - y1 + x1 * block_dimx}; - }; -} - -tvm::Array WeightLayoutTiling2D::ToActualShape(const tvm::Tensor& X) const { - ORT_ENFORCE(X->dtype == HalideIR::type_of() || - X->dtype == HalideIR::type_of()); - - auto pad_row = tvm::make_const(tvm::Int(32), (vector_width_ * num_bits) / X->dtype.bits()); - auto pad_col = tvm::make_const(tvm::Int(32), vector_width_ / sizeof(int32_t)); - - auto new_shape0 = ((X->shape[1] + pad_col - 1) / pad_col) * pad_col; - auto new_shape1 = ((X->shape[0] + pad_row - 1) / pad_row) * pad_row; - - tvm::Array - new_shape = { - new_shape0, - new_shape1}; - return new_shape; -} - -std::vector WeightLayoutTiling2D::ToActualShape(const Tensor* X) const { - ORT_ENFORCE(X != nullptr); - ORT_ENFORCE(X->Shape().GetDims().size() == 2); - - int pad_row = vector_width_ / X->DataType()->Size(); - int pad_col = vector_width_ / sizeof(int32_t); - - auto old_shape = X->Shape().GetDims(); - auto new_shape0 = (old_shape[1] + pad_col - 1) / pad_col * pad_col; - auto new_shape1 = ((old_shape[0] + pad_row - 1) / pad_row) * pad_row; - - std::vector new_shape = { - new_shape0, - new_shape1}; - - return new_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.h b/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.h deleted file mode 100644 index 64334a069f94f..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/codegen/passes/weight_layout/weight_layout.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -/* - * \class! WeightLayoutTiling2D - * \breif! Transform 2D weight to 4D by tiling both dimension, - * this layout is used for tensorization. - * [M, N] => [M/Tx, N/Ty, Tx, Ty] - */ - -class WeightLayoutTiling2D : public WeightLayout { - public: - static const std::string GetKey(ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width); - - public: - WeightLayoutTiling2D(ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width); - - ~WeightLayoutTiling2D() = default; - - CoordTransFunc ToNominal(const tvm::Tensor& X) const override; - CoordTransFunc ToActual(const tvm::Tensor& X) const override; - tvm::Array ToActualShape(const tvm::Tensor& X) const override; - std::vector ToActualShape(const Tensor* X) const override; - - private: - int vector_width_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayoutTiling2D); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.cc b/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.cc deleted file mode 100644 index ea8597f7dd89d..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/weight_layout/transpose_2d.h" - -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace tvm_codegen { - -constexpr auto local_layout_name = "transpose_2d"; - -const std::string WeightLayoutTranspose2D::GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type) { - return WeightLayout::GetKey(local_layout_name, proto_type, 2, 0.0f); -} - -WeightLayoutTranspose2D::WeightLayoutTranspose2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type) - : WeightLayout(local_layout_name, proto_type, 2, 0.0f) {} - -CoordTransFunc WeightLayoutTranspose2D::ToActual(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& nominal_coord) { - ORT_ENFORCE(nominal_coord.size() == 2); - const auto& y = nominal_coord[0]; - const auto& x = nominal_coord[1]; - return tvm::Array{ - x, - y}; - }; -} - -CoordTransFunc WeightLayoutTranspose2D::ToNominal(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& actual_coord) { - ORT_ENFORCE(actual_coord.size() == 2); - const auto& y = actual_coord[0]; - const auto& x = actual_coord[1]; - return tvm::Array{ - x, - y}; - }; -} - -tvm::Array WeightLayoutTranspose2D::ToActualShape(const tvm::Tensor& X) const { - tvm::Array new_shape = { - X->shape[1], - X->shape[0]}; - return new_shape; -} - -std::vector WeightLayoutTranspose2D::ToActualShape(const Tensor* X) const { - ORT_ENFORCE(X != nullptr); - ORT_ENFORCE(X->Shape().GetDims().size() == 2); - auto old_shape = X->Shape().GetDims(); - - std::vector new_shape = { - old_shape[1], - old_shape[0]}; - - return new_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.h b/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.h deleted file mode 100644 index 65babaaec8dac..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
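WeightLayoutTranspose2D, deleted above, is the simplest of these layouts: both coordinate transforms just swap the two indices (the mapping is its own inverse), and the actual shape is the nominal shape with its dimensions exchanged. A plain-C++ sketch of that mapping; Coord2D and the function names are illustrative, with std::array standing in for the tvm arrays:

```cpp
#include <array>
#include <cstdint>

using Coord2D = std::array<int64_t, 2>;

// [W, H] stored as [H, W]: swapping indices converts between nominal and actual
// coordinates, and applying the swap twice returns the original coordinate.
Coord2D TransposeCoord(const Coord2D& c) { return {c[1], c[0]}; }
Coord2D TransposeShape(const Coord2D& shape) { return {shape[1], shape[0]}; }

// Example: nominal element (row 3, col 7) of a 10x20 weight lives at (7, 3)
// in the 20x10 transposed buffer.
```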
- -#pragma once - -#include "core/codegen/passes/weight_layout/weight_layout.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// WeightLayoutTranspose2D for transposing a 2D weight -// [W, H] => [H, W] -class WeightLayoutTranspose2D : public WeightLayout { - public: - static const std::string GetKey(ONNX_NAMESPACE::TensorProto_DataType proto_type); - - public: - WeightLayoutTranspose2D(ONNX_NAMESPACE::TensorProto_DataType proto_type); - - ~WeightLayoutTranspose2D() = default; - - CoordTransFunc ToNominal(const tvm::Tensor& X) const override; - CoordTransFunc ToActual(const tvm::Tensor& X) const override; - tvm::Array ToActualShape(const tvm::Tensor& X) const override; - std::vector ToActualShape(const Tensor* X) const override; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayoutTranspose2D); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.cc b/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.cc deleted file mode 100644 index b1ddb791a3b3d..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/weight_layout/vertical_stripes_2d.h" - -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace tvm_codegen { - -constexpr auto local_name_prefix = "vertical_stripe_2d_"; - -const std::string WeightLayoutVerticalStripe2D::GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width) { - return WeightLayout::GetKey( - local_name_prefix + std::to_string(stripe_width), - proto_type, 2, 0.0f); -} - -WeightLayoutVerticalStripe2D::WeightLayoutVerticalStripe2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width) - : WeightLayout( - local_name_prefix + std::to_string(stripe_width), - proto_type, 2, 0.0f), - stripe_width_(stripe_width) { -} - -CoordTransFunc WeightLayoutVerticalStripe2D::ToActual(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& nominal_coord) { - ORT_ENFORCE(nominal_coord.size() == 2); - const auto& y = nominal_coord[0]; - const auto& x = nominal_coord[1]; - return tvm::Array{ - x / stripe_width_, - y, - x % stripe_width_}; - }; -} - -CoordTransFunc WeightLayoutVerticalStripe2D::ToNominal(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& actual_coord) { - ORT_ENFORCE(actual_coord.size() == 3); - const auto& z = actual_coord[0]; - const auto& y = actual_coord[1]; - const auto& x = actual_coord[2]; - return tvm::Array{ - y, - x + stripe_width_ * z}; - }; -} - -tvm::Array WeightLayoutVerticalStripe2D::ToActualShape(const tvm::Tensor& X) const { - tvm::Array new_shape = { - (X->shape[1] + stripe_width_ - 1) / stripe_width_, - X->shape[0], - stripe_width_}; - return new_shape; -} - -std::vector WeightLayoutVerticalStripe2D::ToActualShape(const Tensor* X) const { - ORT_ENFORCE(X != nullptr); - auto old_shape = X->Shape().GetDims(); - - ORT_ENFORCE(old_shape.size() == 2); - - std::vector new_shape = { - (old_shape[1] + stripe_width_ - 1) / stripe_width_, - old_shape[0], - stripe_width_}; - - return new_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.h b/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.h deleted file mode 100644 index 
b9b65025dc014..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/common.h" -#include "core/codegen/passes/weight_layout/weight_layout.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// WeightLayoutVerticalStripe2D for making a 2D weight to 3D, by tiling the lowest (verteical) dimension -// [W, H] => [H/stripe, W, stripe] -class WeightLayoutVerticalStripe2D : public WeightLayout { - public: - static const std::string GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width); - - public: - WeightLayoutVerticalStripe2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width); - - ~WeightLayoutVerticalStripe2D() = default; - - virtual CoordTransFunc ToNominal(const tvm::Tensor& X) const override; - virtual CoordTransFunc ToActual(const tvm::Tensor& X) const override; - tvm::Array ToActualShape(const tvm::Tensor& X) const override; - std::vector ToActualShape(const Tensor* X) const override; - - private: - int stripe_width_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayoutVerticalStripe2D); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc deleted file mode 100644 index ab3e647fd284a..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/weight_layout/weight_layout.h" - -#include "core/codegen/common/common.h" -#include "core/codegen/common/utils.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -static tvm::Tensor CreateTVMPlaceholder( - const std::string& name, - HalideIR::Type type, - int dim) { - tvm::Array shape; - if (dim > 0) { - for (int i = 0; i < dim; ++i) { - shape.push_back(tvm::Var(name + "_v" + std::to_string(i))); - } - } else { - shape.push_back(1); - } - return tvm::placeholder(shape, type, name + "_placeholder"); -} - -const std::string WeightLayout::GetKey( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero) { - std::ostringstream key; - key << name << "_type_" << static_cast(proto_type); - key << "_dim_" << input_dim; - key << "_pad_zero_" << pad_zero; - return NormalizeCppName(key.str()); -} - -WeightLayout::WeightLayout( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero) - : name_(GetKey(name, proto_type, input_dim, pad_zero)), - proto_type_(proto_type), - input_dim_(input_dim), - pad_zero_(pad_zero) {} - -const std::string& WeightLayout::Name() const { - return name_; -} - -void WeightLayout::CreateLayoutMarshallingTVMOp(tvm::Array& inputs, - tvm::Array& outputs) const { - HalideIR::Type halide_type = ToTvmType(proto_type_); - - tvm::Tensor placeholder = CreateTVMPlaceholder(name_, halide_type, input_dim_); - inputs.push_back(placeholder); - - tvm::Array new_shape = ToActualShape(placeholder); - CoordTransFunc new_coord_to_old_coord_func = ToNominal(placeholder); - tvm::Expr pad_zero_expr = tvm::make_const(halide_type, 
pad_zero_); - - tvm::Tensor output = tvm::compute( - new_shape, - [&](const tvm::Array& output_coord) { - tvm::Array output_coord1; - for (const auto& coord : output_coord) - output_coord1.push_back(coord); - auto input_coord = new_coord_to_old_coord_func(output_coord1); - ORT_ENFORCE(input_coord.size() == placeholder->shape.size()); - - if (input_coord.size() > 0) { - auto in_range = (input_coord[0] >= 0) && (input_coord[0] < placeholder->shape[0]); - for (size_t dim = 1; dim < input_coord.size(); ++dim) - in_range = in_range && (input_coord[dim] >= 0) && (input_coord[dim] < placeholder->shape[dim]); - - return tvm::if_then_else(in_range, placeholder(input_coord), pad_zero_expr); - } else { - // scalar - return placeholder(input_coord); - } - }); - - outputs.push_back(output); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h deleted file mode 100644 index 1b45a38e7e24e..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/common.h" -#include "core/codegen/common/registry.h" -#include "core/common/common.h" -#include "core/framework/tensor.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -using CoordTransFunc = std::function(const tvm::Array&)>; - -// WeightLayout is data layout transformer for weight/initializer -class WeightLayout { - public: - // Static function to return unique string as a key - static const std::string GetKey( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero); - - public: - WeightLayout( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero); - - virtual ~WeightLayout() = default; - - // Return a CoordTransFunc from actual (transformed) coordinate to normial (original) coordinate - virtual CoordTransFunc ToNominal(const tvm::Tensor& X) const = 0; - - // Return a CoordTransFunc from normial (original) coordinate to actual (transformed) coordinate - virtual CoordTransFunc ToActual(const tvm::Tensor& X) const = 0; - - // Return actual (transformed) shape in tvm::Array (tvm_codegen) - virtual tvm::Array ToActualShape(const tvm::Tensor& X) const = 0; - - // Return actual (transformed) shape in vector (ort) - virtual std::vector ToActualShape(const Tensor* X) const = 0; - - // Create Layout Marshalling op in outputs - void CreateLayoutMarshallingTVMOp(tvm::Array& inputs, - tvm::Array& outputs) const; - - // Layout name - const std::string& Name() const; - - protected: - std::string name_; - ONNX_NAMESPACE::TensorProto_DataType proto_type_; - int input_dim_; - float pad_zero_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayout); -}; - -// Weight Layout Registry is a registry holds all WeightLayout -using WeightLayoutRegistry = codegen::RegistryBase; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc index a086c90ea4b14..a79e7300cffce 100644 --- a/onnxruntime/core/common/logging/logging.cc +++ b/onnxruntime/core/common/logging/logging.cc @@ -64,13 +64,13 @@ LoggingManager* LoggingManager::GetDefaultInstance() { #pragma warning(disable : 26426) #endif -static OrtMutex& 
DefaultLoggerMutex() noexcept { - static OrtMutex mutex; +static std::mutex& DefaultLoggerMutex() noexcept { + static std::mutex mutex; return mutex; } Logger* LoggingManager::s_default_logger_ = nullptr; -OrtMutex sink_mutex_; +std::mutex sink_mutex_; #ifdef _MSC_VER #pragma warning(pop) @@ -107,7 +107,7 @@ LoggingManager::LoggingManager(std::unique_ptr sink, Severity default_min // lock mutex to create instance, and enable logging // this matches the mutex usage in Shutdown - std::lock_guard guard(DefaultLoggerMutex()); + std::lock_guard guard(DefaultLoggerMutex()); if (DefaultLoggerManagerInstance().load() != nullptr) { ORT_THROW("Only one instance of LoggingManager created with InstanceType::Default can exist at any point in time."); @@ -127,7 +127,7 @@ LoggingManager::LoggingManager(std::unique_ptr sink, Severity default_min LoggingManager::~LoggingManager() { if (owns_default_logger_) { // lock mutex to reset DefaultLoggerManagerInstance() and free default logger from this instance. - std::lock_guard guard(DefaultLoggerMutex()); + std::lock_guard guard(DefaultLoggerMutex()); #if ((__cplusplus >= 201703L) || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201703L))) DefaultLoggerManagerInstance().store(nullptr, std::memory_order_release); #else @@ -283,7 +283,7 @@ Severity OverrideLevelWithEtw(Severity original_severity) { bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function()> sinkFactory, logging::Severity severity) { - std::lock_guard guard(sink_mutex_); + std::lock_guard guard(sink_mutex_); if (sink_->GetType() != SinkType::CompositeSink) { // Current sink is not a composite, create a new composite sink and add the current sink to it auto new_composite = std::make_unique(); @@ -305,7 +305,7 @@ bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function guard(sink_mutex_); + std::lock_guard guard(sink_mutex_); if (sink_->GetType() == SinkType::CompositeSink) { auto composite_sink = static_cast(sink_.get()); diff --git a/onnxruntime/core/common/profiler.cc b/onnxruntime/core/common/profiler.cc index 71bca6ef3b582..8562e5524af74 100644 --- a/onnxruntime/core/common/profiler.cc +++ b/onnxruntime/core/common/profiler.cc @@ -85,7 +85,7 @@ void Profiler::EndTimeAndRecordEvent(EventCategory category, custom_logger_->SendProfileEvent(event); } else { // TODO: sync_gpu if needed. 
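These hunks swap the custom OrtMutex and its lock guard for the standard <mutex> types. The pattern kept throughout is a function-local static mutex, whose initialization is thread-safe since C++11, taken via std::lock_guard. A minimal standalone illustration of that idiom, not the ORT sources themselves:

```cpp
#include <iostream>
#include <mutex>

// Function-local static: constructed once on first use and free of
// global construction-order issues.
static std::mutex& DefaultMutex() noexcept {
  static std::mutex m;
  return m;
}

void LogLine(const char* msg) {
  std::lock_guard<std::mutex> guard(DefaultMutex());  // released automatically at scope exit
  std::cout << msg << '\n';
}

int main() {
  LogLine("guarded by std::mutex, no custom wrapper needed");
  return 0;
}
```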
- std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); if (events_.size() < max_num_events_) { events_.emplace_back(std::move(event)); } else { @@ -115,7 +115,7 @@ std::string Profiler::EndProfiling() { LOGS(*session_logger_, INFO) << "Writing profiler data to file " << profile_stream_file_; } - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); profile_stream_ << "[\n"; for (const auto& ep_profiler : ep_profilers_) { diff --git a/onnxruntime/core/common/profiler.h b/onnxruntime/core/common/profiler.h index a0bca0007b245..0103d8abb151f 100644 --- a/onnxruntime/core/common/profiler.h +++ b/onnxruntime/core/common/profiler.h @@ -11,7 +11,7 @@ #include "core/common/profiler_common.h" #include "core/common/logging/logging.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -130,7 +130,7 @@ class Profiler { static std::atomic global_max_num_events_; // Mutex controlling access to profiler data - OrtMutex mutex_; + std::mutex mutex_; bool enabled_{false}; #if defined(__wasm__) /* diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index 7b62de799b6fc..b192688373851 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -21,9 +21,10 @@ limitations under the License. #include "core/common/cpuid_info.h" #include "core/common/eigen_common_wrapper.h" #include "core/platform/EigenNonBlockingThreadPool.h" -#include "core/platform/ort_mutex.h" +#include #if !defined(ORT_MINIMAL_BUILD) #ifdef _WIN32 +#include #include "processthreadsapi.h" #include #include diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 5dca4cf6c165b..ecd3960107926 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -138,7 +138,8 @@ class PlannerImpl { const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps, const InlinedHashMap& outer_scope_node_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, - const ISequentialPlannerContext& context, SequentialExecutionPlan& plan) + const ISequentialPlannerContext& context, SequentialExecutionPlan& plan, + const logging::Logger& logger) : context_(&context), plan_(plan), parent_node_(parent_node), @@ -148,14 +149,15 @@ class PlannerImpl { kernel_create_info_map_(kernel_create_info_map), subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps), outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map), - ort_value_name_idx_map_(ort_value_name_idx_map) {} + ort_value_name_idx_map_(ort_value_name_idx_map), + logger_(logger) { + } Status CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger); + const PathString& partition_config_file); private: gsl::not_null context_; @@ -183,6 +185,12 @@ class PlannerImpl { InlinedHashMap> dependence_graph_; InlinedHashMap value_node_map_; + // logger_ is not currently used in a minimal build +#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) + [[maybe_unused]] +#endif + const logging::Logger& logger_; + // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: struct OrtValueInfo { const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue @@ -213,6 +221,7 @@ class PlannerImpl { FreeBufferInfo(OrtValueIndex ort_value, size_t 
dealloc_point) : ml_value(ort_value), deallocate_point(dealloc_point) {} }; + // freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when // they became free (more recently freed earlier in the list). std::list freelist_; @@ -225,7 +234,8 @@ class PlannerImpl { } int& UseCount(OrtValueIndex n) { - ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size()); + ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), + "invalid value index: ", n, " against size ", ort_value_info_.size()); return ort_value_info_[n].usecount; } int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); } @@ -335,9 +345,9 @@ class PlannerImpl { // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -361,9 +371,9 @@ class PlannerImpl { // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -397,8 +407,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -448,8 +458,8 @@ class PlannerImpl { return true; } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " - << producer_node->Name() << " as it has external outputs."; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " + << producer_node->Name() << " as it has external outputs."; #endif } } @@ -1198,9 +1208,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. 
" - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1241,9 +1251,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1290,8 +1300,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -1869,8 +1879,7 @@ class PlannerImpl { } #ifndef ORT_ENABLE_STREAM - void PartitionIntoStreams(const logging::Logger& /*logger*/, - const ExecutionProviders& /*execution_providers*/, + void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/, const PathString& /*partition_config_file*/) { if (graph_viewer_.NumberOfNodes() > 0) { stream_nodes_.push_back({}); @@ -1915,11 +1924,11 @@ class PlannerImpl { #else - void - PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, - const PathString& partition_config_file) { - auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file); - auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder()); + void PartitionIntoStreams(const ExecutionProviders& execution_providers, + const PathString& partition_config_file) { + auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file); + auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, + context_->GetExecutionOrder()); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); plan_.node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); for (size_t i = 0; i < stream_nodes_.size(); ++i) { @@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger) { + const PathString& partition_config_file) { // 1. partition graph into streams - PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file); + PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file); // 2. 
initialize the plan based on stream partition result int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; @@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan( PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_create_info_map, subgraphs_kernel_create_info_maps, outer_scope_node_arg_to_location_map, - ort_value_name_idx_map, context, *plan); + ort_value_name_idx_map, context, *plan, logger); return planner.CreatePlan( #ifdef ORT_ENABLE_STREAM stream_handle_registry, #endif - partition_config_file, - logger); + partition_config_file); } #ifdef ORT_ENABLE_STREAM diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index b6dc8ad56f257..26b98b0a04d24 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1); } else if (strcmp(name1, onnxruntime::CUDA) == 0 || strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 || - strcmp(name1, onnxruntime::DML) == 0 || strcmp(name1, onnxruntime::HIP) == 0 || strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::DML) == 0) { + *out = new OrtMemoryInfo( + name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, diff --git a/onnxruntime/core/framework/allocator_utils.cc b/onnxruntime/core/framework/allocator_utils.cc index 7493ac7d0a4e8..edf965d3835b5 100644 --- a/onnxruntime/core/framework/allocator_utils.cc +++ b/onnxruntime/core/framework/allocator_utils.cc @@ -8,6 +8,8 @@ #include #include +#include + #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/framework/bfc_arena.h" @@ -75,4 +77,21 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) { } } +bool DoesCpuAllocatorSupportArenaUsage() { +#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) + // We use these allocators instead of the arena. + return false; +#elif defined(ABSL_HAVE_ADDRESS_SANITIZER) + // Using the arena may hide memory issues. Disable it in an ASan build. + return false; +#else + // Disable the arena for 32-bit builds because it may run into an infinite loop on integer overflow. + if constexpr (sizeof(void*) == 4) { + return false; + } else { + return true; + } +#endif +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocator_utils.h b/onnxruntime/core/framework/allocator_utils.h index 7dda1d1a6fd8f..bef0b7057a7f8 100644 --- a/onnxruntime/core/framework/allocator_utils.h +++ b/onnxruntime/core/framework/allocator_utils.h @@ -42,4 +42,9 @@ struct AllocatorCreationInfo { // Valid values can be found in onnxruntime_c_api.h. AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info); +/** + * Gets whether a CPU allocator supports arena usage. 
+ */ +bool DoesCpuAllocatorSupportArenaUsage(); + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index 13f9656ae0595..6788b4af3b982 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -276,7 +276,7 @@ void* BFCArena::Reserve(size_t size) { if (size == 0) return nullptr; - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); LOGS_DEFAULT(INFO) << "Reserving memory in BFCArena for " << device_allocator_->Info().name << " size: " << size; @@ -293,7 +293,7 @@ void* BFCArena::Reserve(size_t size) { } size_t BFCArena::RequestedSize(const void* ptr) { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); BFCArena::ChunkHandle h = region_manager_.get_handle(ptr); ORT_ENFORCE(h != kInvalidChunkHandle); BFCArena::Chunk* c = ChunkFromHandle(h); @@ -301,7 +301,7 @@ size_t BFCArena::RequestedSize(const void* ptr) { } size_t BFCArena::AllocatedSize(const void* ptr) { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); BFCArena::ChunkHandle h = region_manager_.get_handle(ptr); ORT_ENFORCE(h != kInvalidChunkHandle); BFCArena::Chunk* c = ChunkFromHandle(h); @@ -325,7 +325,7 @@ void* BFCArena::AllocateRawInternal(size_t num_bytes, // The BFC allocator tries to find the best fit first. BinNum bin_num = BinNumForSize(rounded_bytes); - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); // search for a valid chunk auto* chunk = FindChunkPtr(bin_num, rounded_bytes, @@ -377,7 +377,7 @@ void* BFCArena::AllocateRawInternal(size_t num_bytes, } void BFCArena::GetStats(AllocatorStats* stats) { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); *stats = stats_; } @@ -496,7 +496,7 @@ void BFCArena::Free(void* p) { if (p == nullptr) { return; } - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); auto it = reserved_chunks_.find(p); if (it != reserved_chunks_.end()) { device_allocator_->Free(it->first); @@ -509,7 +509,7 @@ void BFCArena::Free(void* p) { } Status BFCArena::Shrink() { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); auto num_regions = region_manager_.regions().size(); std::vector region_ptrs; std::vector region_sizes; @@ -807,7 +807,7 @@ void BFCArena::DumpMemoryLog(size_t num_bytes) { } #ifdef ORT_ENABLE_STREAM void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_flag) { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); for (const auto& region : region_manager_.regions()) { ChunkHandle region_begin_chunk = region_manager_.get_handle(region.ptr()); diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index 5e4cd9f62f11b..8081738f2a5dc 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -27,7 +27,7 @@ limitations under the License. 
#include "core/common/logging/severity.h" #include "core/common/safeint.h" -#include "core/platform/ort_mutex.h" +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/allocator.h" @@ -489,7 +489,7 @@ class BFCArena : public IAllocator { std::unique_ptr device_allocator_; - mutable OrtMutex lock_; + mutable std::mutex lock_; RegionManager region_manager_; std::vector chunks_; diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 43fe92edc9dfe..29cf79ec385d8 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -12,6 +12,7 @@ #include "core/graph/graph_viewer.h" #include "core/common/logging/logging.h" #ifdef _WIN32 +#include #include #include #include "core/platform/tracing.h" diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc index ea6c499829391..fe73a55735631 100644 --- a/onnxruntime/core/framework/external_data_loader.cc +++ b/onnxruntime/core/framework/external_data_loader.cc @@ -32,7 +32,7 @@ common::Status LoadWebAssemblyExternalData(const Env& env, if (typeof Module == 'undefined' || !Module.MountedFiles) { return 1; // "Module.MountedFiles" is not available. } - let fileName = UTF8ToString($0 >>> 0); + let fileName = UTF8ToString(Number($0 >>> 0)); if (fileName.startsWith('./')) { fileName = fileName.substring(2); } @@ -40,9 +40,9 @@ common::Status LoadWebAssemblyExternalData(const Env& env, if (!fileData) { return 2; // File not found in preloaded files. } - const offset = $1 >>> 0; - const length = $2 >>> 0; - const dataIdOrBuffer = $3 >>> 0; + const offset = Number($1 >>> 0); + const length = Number($2 >>> 0); + const dataIdOrBuffer = Number($3 >>> 0); const loadType = $4; if (offset + length > fileData.byteLength) { diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index ef68b88187e08..1eb7420b44d2c 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { + gsl::span tentative_nodes, + const logging::Logger& logger) { // automatic conversion from const std::vector& const auto& ordered_nodes = graph.GetNodesInTopologicalOrder(); InlinedVector node_id_to_order_map(graph.MaxNodeIndex()); @@ -83,7 +84,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { candidates.push(consumer_node->Index()); - LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); + LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); } } return Status::OK(); @@ -159,9 +160,9 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe if (place_in_cpu) { cpu_nodes.insert(cur); - LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name() - << " because the CPU execution path is deemed faster than overhead involved with execution on other EPs " - << " capable of executing this node"; + LOGS(logger, INFO) << "ORT optimization- Force fallback 
to CPU execution for node: " << node->Name() + << " because the CPU execution path is deemed faster than overhead involved with execution " + "on other EPs capable of executing this node"; for (auto* output : node->OutputDefs()) { cpu_output_args.insert(output); } diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index c5bcd22888b7c..bca75adbfd5a7 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -9,6 +9,9 @@ #include "core/graph/graph_viewer.h" namespace onnxruntime { +namespace logging { +class Logger; +} /** Returns a list of nodes that are preferred on CPU. @@ -19,6 +22,7 @@ namespace onnxruntime { */ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); } // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 4f745b74abce7..406fc1b15effc 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep, }; } // namespace -static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { +static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) { auto& current_ep = params.current_ep.get(); const auto& ep_type = current_ep.Type(); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) { - LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " + LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " "the layout transformer."; return Status::OK(); } @@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); @@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, const IExecutionProvider& current_ep, + const logging::Logger& logger, std::vector>& capabilities) { const auto& ep_type = current_ep.Type(); const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; // TODO: Provide EP with a capability to look inside the functions. 
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); @@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, GraphPartitioner::Mode mode, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, - const layout_transformation::DebugGraphFn& debug_graph_fn) { + const layout_transformation::DebugGraphFn& debug_graph_fn, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn)); + transform_layout_fn, debug_graph_fn, logger)); } } @@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(transform_layout_fn), std::cref(debug_graph_fn)}; - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach - if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) { + if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { nodes_to_compile.push_back(n); capabilities_to_compile.push_back(std::move(capability)); } else { @@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, Graph& graph, + const logging::Logger& logger, InlinedHashSet& not_inlined, size_t& inlined_count) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. 
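These graph_partitioner.cc hunks thread a const logging::Logger& into GetCapabilityForEP, GetCapabilityForEPForAotInlining and the KernelLookup constructor so that messages go through the session's logger (LOGS(logger, ...)) rather than the process-wide default (LOGS_DEFAULT). Below is a minimal, self-contained sketch of that injection pattern; the Logger and Partitioner types are simplified stand-ins for illustration only, not ONNX Runtime's actual logging classes.

#include <iostream>
#include <string>

// Simplified stand-in for a session-scoped logger (not ORT's logging::Logger).
class Logger {
 public:
  Logger(std::ostream& sink, std::string id) : sink_(sink), id_(std::move(id)) {}
  void Info(const std::string& msg) const { sink_ << "[" << id_ << "] INFO: " << msg << "\n"; }

 private:
  std::ostream& sink_;
  std::string id_;
};

// A component that previously logged through a global default logger now keeps a
// reference to the logger it was constructed with and uses it everywhere.
class Partitioner {
 public:
  explicit Partitioner(const Logger& logger) : logger_(logger) {}

  void Partition(const std::string& node_name) const {
    // In spirit: LOGS(logger_, INFO) << ... instead of LOGS_DEFAULT(INFO) << ...
    logger_.Info("Assigning node " + node_name);
  }

 private:
  const Logger& logger_;  // owned by the session, outlives this object
};

int main() {
  Logger session_a(std::cout, "session-A");
  Logger session_b(std::cerr, "session-B");

  Partitioner(session_a).Partition("Conv_0");   // goes to session A's sink
  Partitioner(session_b).Partition("MatMul_1"); // goes to session B's sink
  return 0;
}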
@@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, *subgraph, + logger, not_inlined, inlined_count)); } @@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities)); + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, + capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; if (nodes.size() == 1) { @@ -674,7 +681,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers context_cache_path, "' exist already."); } - Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), {}, logger); auto& ep_graph = ep_context_model.MainGraph(); ep_graph.SetDescription(graph.Description()); @@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { bool modified_graph = false; auto& graph = partition_params.graph.get(); @@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, - partition_params.debug_graph_fn)); + partition_params.debug_graph_fn, + logger)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, - IExecutionProvider& current_ep) { + IExecutionProvider& current_ep, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. 
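InlineFunctionsAOTImpl above recurses into every nested subgraph before processing the current level, passing the same logger along with the not_inlined set and the inlined_count accumulator down the recursion. The sketch below shows only that traversal shape, using toy Graph and function types that are assumptions for illustration, not ORT's real graph API.

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Toy stand-in: a graph holds local function calls plus nested subgraphs.
struct Graph {
  std::string name;
  std::vector<std::string> function_calls;
  std::vector<Graph> subgraphs;
};

// Toy policy: anything not starting with "com.vendor" can be inlined.
bool CanInline(const std::string& fn) { return fn.rfind("com.vendor", 0) != 0; }

void InlineFunctionsAOTSketch(const Graph& graph,
                              std::ostream& log,               // same logger threaded through every level
                              std::set<std::string>& not_inlined,
                              size_t& inlined_count) {
  // Depth-first: handle nested subgraphs before the current level.
  for (const Graph& subgraph : graph.subgraphs) {
    InlineFunctionsAOTSketch(subgraph, log, not_inlined, inlined_count);
  }

  for (const std::string& fn : graph.function_calls) {
    if (CanInline(fn)) {
      ++inlined_count;
      log << "inlined " << fn << " in " << graph.name << "\n";
    } else {
      not_inlined.insert(fn);  // remembered so the caller can report what was left alone
    }
  }
}

int main() {
  Graph then_branch{"then_branch", {"com.vendor.FusedAttention"}, {}};
  Graph main_graph{"main", {"LayerNorm", "Gelu"}, {then_branch}};

  std::set<std::string> not_inlined;
  size_t inlined_count = 0;
  InlineFunctionsAOTSketch(main_graph, std::cout, not_inlined, inlined_count);

  std::cout << "inlined: " << inlined_count << ", skipped: " << not_inlined.size() << "\n";
  return 0;
}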
// doing it here saves all providers checking for this in GetCapability auto& graph = partition_params.graph.get(); @@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param auto& subgraph = *entry.second; PartitionParams subgraph_partition_params = partition_params; subgraph_partition_params.graph = std::ref(subgraph); - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep, + logger)); } } @@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param }; // clang-format on - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param // Simplified partitioning where custom EPs may produce compiled nodes. static Status PartitionOrtFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { // process full graph with each EP for (const auto& ep : execution_providers) { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger)); } return Status::OK(); @@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, graph, + logger, not_inlined, inlined_count)); @@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger)); bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); @@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! 
defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger)); } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h index 0dd17d2f4a624..fac43bad0fefb 100644 --- a/onnxruntime/core/framework/kernel_lookup.h +++ b/onnxruntime/core/framework/kernel_lookup.h @@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { public: KernelLookup(ProviderType provider_type, gsl::span> kernel_registries, - const IKernelTypeStrResolver& kernel_type_str_resolver) + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) : provider_type_{provider_type}, kernel_registries_{kernel_registries}, - kernel_type_str_resolver_{kernel_type_str_resolver} { + kernel_type_str_resolver_{kernel_type_str_resolver}, + logger_{logger} { ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified."); } const KernelCreateInfo* LookUpKernel(const Node& node) const override { const KernelCreateInfo* kernel_create_info{}; for (const auto& registry : kernel_registries_) { - const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, + const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return kernel_create_info; @@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { ProviderType provider_type_; const gsl::span> kernel_registries_; const IKernelTypeStrResolver& kernel_type_str_resolver_; + const logging::Logger& logger_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index d695e0e04c2b0..8602a3b4004ff 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include "core/framework/kernel_type_str_resolver.h" @@ -182,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { const auto& node_provider = node.GetExecutionProviderType(); const auto& expected_provider = (node_provider.empty() ? 
exec_provider : node_provider); @@ -214,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } @@ -223,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out); + return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out); } Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out); + return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out); } static bool KernelDefCompatible(int version, const KernelDef& kernel_def, @@ -260,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider)); if (out) *out = nullptr; @@ -288,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } @@ -310,9 +315,12 @@ Status KernelRegistry::Register(KernelCreateInfo&& create_info) { for (auto i = range.first; i != range.second; ++i) { if (i->second.kernel_def && i->second.kernel_def->IsConflict(*create_info.kernel_def)) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Failed to add kernel for " + key + - ": Conflicting with a registered kernel with op versions."); + int since_version = i->second.kernel_def->SinceVersion().first; + std::string since_version_str = std::to_string(since_version); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to add kernel for ", key, + ": Conflicting with a registered kernel with op versions. 
the since version is: ", + since_version_str); } } diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index f8ccdb8fb0238..721353854a474 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptrTryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if (status.IsOK()) { return status; } @@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, } if (p != nullptr) { - status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if (status.IsOK()) { return status; } @@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); } -bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) { +bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, + const Node& node, + const std::string& provider_type, + const logging::Logger& logger) { const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type); return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) { - return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver()); + return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(), + logger); }); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 201fda6d978b6..72f0ed3c6268a 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -12,7 +12,7 @@ #include "core/common/status.h" #include "core/framework/kernel_type_str_resolver.h" #include "core/graph/graph_viewer.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { struct KernelCreateInfo; @@ -67,13 +67,14 @@ class KernelRegistryManager { // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done - Status SearchKernelRegistry(const Node& node, + Status SearchKernelRegistry(const Node& node, const logging::Logger& logger, /*out*/ const KernelCreateInfo** kernel_create_info) const; /** * Whether this node can be run on this provider */ - static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type); + static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type, + const logging::Logger& logger); Status CreateKernel(const Node& node, const IExecutionProvider& execution_provider, diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h index 587be491b360a..a642649eca341 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver.h @@ -18,7 +18,7 @@ 
#include "core/common/status.h" #include "core/graph/op_identifier.h" #include "core/graph/graph.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -129,7 +129,7 @@ class OpSchemaKernelTypeStrResolver final : public IKernelTypeStrResolver { // used as a cache when resolving // since the cache may be modified with a const instance, ensure that access to the cache is thread-safe mutable KernelTypeStrResolver resolver_; - mutable OrtMutex resolver_mutex_; + mutable std::mutex resolver_mutex_; }; #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/mem_pattern_planner.h b/onnxruntime/core/framework/mem_pattern_planner.h index f4db5d9f1c75f..e4353ec22db92 100644 --- a/onnxruntime/core/framework/mem_pattern_planner.h +++ b/onnxruntime/core/framework/mem_pattern_planner.h @@ -20,7 +20,7 @@ limitations under the License. #include "core/common/safeint.h" #include "core/framework/mem_pattern.h" #include "core/framework/allocation_planner.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { // MemPatternPlanner is used to trace allocation/free steps @@ -68,7 +68,7 @@ class MemPatternPlanner { void TraceAllocation(int ml_value_idx, const AllocPlanPerValue::ProgramCounter& counter, size_t size) { ORT_ENFORCE(using_counters_); - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); if (size == 0) { allocs_.emplace_back(ml_value_idx, MemoryBlock(0, 0)); @@ -133,7 +133,7 @@ class MemPatternPlanner { void TraceAllocation(int ml_value_idx, size_t size) { ORT_ENFORCE(!using_counters_); - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); if (size == 0) { allocs_.emplace_back(ml_value_idx, MemoryBlock(0, 0)); @@ -190,7 +190,7 @@ class MemPatternPlanner { } void TraceFree(int ml_value_index) { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); for (auto it = blocks_.begin(); it != blocks_.end(); it++) { if (allocs_[*it].index_ == ml_value_index) { @@ -201,7 +201,7 @@ class MemPatternPlanner { } MemoryPattern GenerateMemPattern() const { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); #ifdef ENABLE_TRAINING if (using_counters_) { @@ -261,7 +261,7 @@ class MemPatternPlanner { std::list blocks_; SafeInt buffer_size_{0}; bool using_counters_; - mutable OrtMutex lock_; + mutable std::mutex lock_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/model_metadef_id_generator.cc b/onnxruntime/core/framework/model_metadef_id_generator.cc index 8b1d1f4f304c9..4a35052d159a0 100644 --- a/onnxruntime/core/framework/model_metadef_id_generator.cc +++ b/onnxruntime/core/framework/model_metadef_id_generator.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include #include "model_metadef_id_generator.h" -#include "core/platform/ort_mutex.h" +#include #include "core/graph/graph_viewer.h" #include "core/framework/murmurhash3.h" @@ -11,8 +11,8 @@ int ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_vi HashValue& model_hash) const { // if the EP is shared across multiple sessions there's a very small potential for concurrency issues. 
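Across these files the custom OrtMutex (core/platform/ort_mutex.h) is replaced by the standard <mutex> primitives: std::mutex members locked with std::lock_guard, plus a function-local static std::mutex in ModelMetadefIdGenerator::GenerateId just below. The stand-alone snippet here only illustrates those two standard-library usages; EventRecorder and the GenerateId shown are generic examples, not ORT code.

#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Pattern 1: a member mutex guarding shared state, locked with std::lock_guard.
class EventRecorder {
 public:
  void Record(int event) {
    std::lock_guard<std::mutex> lock(mutex_);
    events_.push_back(event);
  }

  size_t Count() const {
    std::lock_guard<std::mutex> lock(mutex_);  // mutable so const methods can lock
    return events_.size();
  }

 private:
  mutable std::mutex mutex_;
  std::vector<int> events_;
};

// Pattern 2: a function-local static mutex serializing a rarely-contended operation.
int GenerateId() {
  static std::mutex mutex;
  static int next_id = 0;
  std::lock_guard<std::mutex> lock(mutex);
  return next_id++;
}

int main() {
  EventRecorder recorder;
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&recorder] {
      for (int i = 0; i < 100; ++i) recorder.Record(GenerateId());
    });
  }
  for (auto& w : workers) w.join();
  std::cout << "recorded " << recorder.Count() << " events\n";  // 400
  return 0;
}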
// use a lock when generating an id to be paranoid - static OrtMutex mutex; - std::lock_guard lock(mutex); + static std::mutex mutex; + std::lock_guard lock(mutex); model_hash = 0; // find the top level graph diff --git a/onnxruntime/core/framework/prepacked_weights_container.h b/onnxruntime/core/framework/prepacked_weights_container.h index 7fe317b6c4317..37fc01c05f2ae 100644 --- a/onnxruntime/core/framework/prepacked_weights_container.h +++ b/onnxruntime/core/framework/prepacked_weights_container.h @@ -11,7 +11,7 @@ #include "core/framework/buffer_deleter.h" #include "core/framework/allocator.h" -#include "core/platform/ort_mutex.h" +#include #include "prepacked_weights.h" namespace onnxruntime { @@ -53,7 +53,7 @@ class PrepackedWeightsContainer final { // PrePack() methods and does the read/write into the pre-packed weights' container. // We only want to invoke PrePack() on a kernel that doesn't have a cached version // of its pre-packed weight. - OrtMutex mutex_; + std::mutex mutex_; // Define allocators ahead of the container containing tensors because the allocators // needs to destructed after the container containing the pre-packed cached tensors diff --git a/onnxruntime/core/framework/random_generator.h b/onnxruntime/core/framework/random_generator.h index 39f31b2f9af8a..b0aa3df09ca62 100644 --- a/onnxruntime/core/framework/random_generator.h +++ b/onnxruntime/core/framework/random_generator.h @@ -7,7 +7,7 @@ #include #include -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -57,7 +57,7 @@ class PhiloxGenerator { * Resets the seed and offset. */ void SetSeed(uint64_t seed) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); seed_ = seed; offset_ = 0; } @@ -66,7 +66,7 @@ class PhiloxGenerator { * Gets the seed and offset pair, incrementing the offset by the specified count. 
*/ std::pair NextPhiloxSeeds(uint64_t count) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); auto seeds = std::make_pair(seed_, offset_); offset_ += count; return seeds; @@ -79,7 +79,7 @@ class PhiloxGenerator { static PhiloxGenerator& Default(); private: - OrtMutex mutex_; + std::mutex mutex_; uint64_t seed_; uint64_t offset_; }; diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index aa762ca32fdb4..2185b8332b9cf 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -68,7 +68,7 @@ static void CalculateTotalOutputSizes(OpKernelContextInternal* op_kernel_context int output_count = op_kernel_context->OutputCount(); for (auto i = 0; i < output_count; i++) { const OrtValue* p_output = op_kernel_context->GetOutputMLValue(i); - if (p_output != nullptr && p_output->IsTensor()) { + if (p_output != nullptr && p_output->IsTensor() && p_output->IsAllocated()) { const auto& tensor = p_output->Get(); size_t tensor_size = tensor.SizeInBytes(); #if defined(TRACE_EXECUTION) @@ -104,7 +104,7 @@ static void CalculateTotalInputSizes(const OpKernelContextInternal* op_kernel_co const int input_count = op_kernel_context->InputCount(); for (auto i = 0; i < input_count; i++) { const OrtValue* p_input = op_kernel_context->GetInputMLValue(i); - if (p_input != nullptr && p_input->IsTensor()) { + if (p_input != nullptr && p_input->IsTensor() && p_input->IsAllocated()) { const OpKernelInfo& op_kernel_info = p_op_kernel->Info(); const Tensor* p_tensor = nullptr; bool is_param = op_kernel_info.TryGetConstantInput(i, &p_tensor); @@ -339,12 +339,6 @@ class KernelScope { if (session_state_.Profiler().IsEnabled()) { auto& node = kernel.Node(); node_name_ = node.Name().empty() ? 
MakeString(node.OpType(), "_", node.Index()) : node.Name(); - auto& profiler = session_state_.Profiler(); - auto sync_time_begin = profiler.Start(); - profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT, - node_name_ + "_fence_before", - sync_time_begin, - {{"op_name", kernel_.KernelDef().OpName()}}); concurrency::ThreadPool::StartProfiling(session_state_.GetThreadPool()); VLOGS(session_state_.Logger(), 1) << "Computing kernel: " << node_name_; kernel_begin_time_ = session_state_.Profiler().Start(); @@ -381,11 +375,6 @@ class KernelScope { {"thread_scheduling_stats", concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())}, }); - auto sync_time_begin = profiler.Start(); - profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT, - node_name_ + "_fence_after", - sync_time_begin, - {{"op_name", kernel_.KernelDef().OpName()}}); } #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 4df0370ac719e..0ac2271ba09f1 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -5,7 +5,7 @@ #include -#include "core/platform/ort_mutex.h" +#include #include "core/common/logging/logging.h" #include "core/common/safeint.h" #include "core/flatbuffers/schema/ort.fbs.h" @@ -178,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne bool saving_ort_format) { for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; - auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); if (!status.IsOK() && saving_ort_format) { // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. // in that case we assigned the node to that EP but do not compile it into a fused node. @@ -187,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. // if that's not possible for some reason we can fallback to the CPU EP implementation. 
node.SetExecutionProviderType(kCpuExecutionProvider); - status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); } ORT_RETURN_IF_ERROR(status); @@ -518,7 +518,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMap l(prepacked_weights_container_->mutex_); + std::lock_guard l(prepacked_weights_container_->mutex_); return prepacked_constant_weights(true); } else { return prepacked_constant_weights(false); @@ -775,7 +775,7 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup( const InlinedHashMap*& out_inferred_shapes) const { out_inferred_shapes = nullptr; int64_t key = CalculateMemoryPatternsKey(tensor_inputs); - std::lock_guard lock(mem_patterns_lock_); + std::lock_guard lock(mem_patterns_lock_); auto it = mem_patterns_.find(key); if (it == mem_patterns_.end()) { #ifdef ENABLE_TRAINING @@ -851,7 +851,7 @@ Status SessionState::UpdateMemoryPatternGroupCache(gsl::span ten MemoryPatternGroup mem_patterns) const { int64_t key = CalculateMemoryPatternsKey(tensor_inputs); - std::lock_guard lock(mem_patterns_lock_); + std::lock_guard lock(mem_patterns_lock_); // Do not update if present, as the pointer to the existing one is cached mem_patterns_.emplace(key, std::move(mem_patterns)); return Status::OK(); @@ -1588,7 +1588,7 @@ static void BindToDeviceStream(const SequentialExecutionPlan& execution_plan, std::unique_ptr SessionState::AcquireDeviceStreamCollection() const { if (has_device_stream_enabled_ep_) { - std::lock_guard lock(device_stream_pool_mutex_); + std::lock_guard lock(device_stream_pool_mutex_); if (!device_stream_pool_.empty()) { auto device_stream = std::move(device_stream_pool_.back()); device_stream_pool_.pop_back(); @@ -1607,7 +1607,7 @@ std::unique_ptr SessionState::AcquireDeviceStreamCollect void SessionState::RecycleDeviceStreamCollection(std::unique_ptr device_stream_collection) const { // if no need to reuse the device stream, don't perform the recycle if (has_device_stream_enabled_ep_) { - std::lock_guard lock(device_stream_pool_mutex_); + std::lock_guard lock(device_stream_pool_mutex_); device_stream_pool_.push_back(std::move(device_stream_collection)); } else { device_stream_collection.reset(nullptr); diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 5b7f6dc5cb867..e1674ba4b690b 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -35,7 +35,7 @@ #include "core/framework/ort_value_name_idx_map.h" #include "core/graph/graph_viewer.h" #include "core/graph/onnx_protobuf.h" -#include "core/platform/ort_mutex.h" +#include #include "core/platform/path_lib.h" #include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) @@ -494,7 +494,7 @@ class SessionState { bool enable_mem_pattern_; // lock for the mem_patterns_ - mutable OrtMutex mem_patterns_lock_; + mutable std::mutex mem_patterns_lock_; // cache for the generated mem_patterns. key is calculated based on input shapes. // must be a node based container as a pointer is cached. mutable NodeHashMap mem_patterns_; @@ -568,7 +568,7 @@ class SessionState { std::unique_ptr stream_handles_registry_; // lock for the device stream pool - mutable OrtMutex device_stream_pool_mutex_; + mutable std::mutex device_stream_pool_mutex_; mutable std::vector> device_stream_pool_; // flag to indicate whether current session using any EP that create device stream dynamically. 
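AcquireDeviceStreamCollection and RecycleDeviceStreamCollection above implement a small pool guarded by device_stream_pool_mutex_: acquisition reuses a pooled entry when one is available, and recycling returns the entry to the pool only when has_device_stream_enabled_ep_ is set (otherwise it is simply destroyed). A self-contained sketch of that acquire/recycle shape follows; StreamCollection and StreamPool are placeholder types, not the ORT classes.

#include <atomic>
#include <iostream>
#include <memory>
#include <mutex>
#include <vector>

struct StreamCollection {  // placeholder for the pooled resource
  int id = 0;
};

class StreamPool {
 public:
  explicit StreamPool(bool pooling_enabled) : pooling_enabled_(pooling_enabled) {}

  std::unique_ptr<StreamCollection> Acquire() {
    if (pooling_enabled_) {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!pool_.empty()) {
        std::unique_ptr<StreamCollection> reused = std::move(pool_.back());  // reuse a recycled entry
        pool_.pop_back();
        return reused;
      }
    }
    auto fresh = std::make_unique<StreamCollection>();
    fresh->id = next_id_.fetch_add(1);
    return fresh;
  }

  void Recycle(std::unique_ptr<StreamCollection> collection) {
    if (pooling_enabled_) {
      std::lock_guard<std::mutex> lock(mutex_);
      pool_.push_back(std::move(collection));
    }
    // Otherwise the unique_ptr destroys it, in the spirit of device_stream_collection.reset(nullptr).
  }

 private:
  const bool pooling_enabled_;
  std::atomic<int> next_id_{0};
  std::mutex mutex_;
  std::vector<std::unique_ptr<StreamCollection>> pool_;
};

int main() {
  StreamPool pool(/*pooling_enabled=*/true);
  auto first = pool.Acquire();   // freshly created, id 0
  pool.Recycle(std::move(first));
  auto second = pool.Acquire();  // comes back from the pool, so still id 0
  std::cout << "acquired collection id: " << second->id << "\n";
  return 0;
}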
bool has_device_stream_enabled_ep_ = false; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 74c359881a1d7..2af9f95ad059e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -165,37 +165,6 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2) DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) -static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size) { - ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), - "Tensor does not have external data to read from."); - - ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), - "External data type cannot be UNDEFINED or STRING."); - - std::unique_ptr external_data_info; - ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - - const auto& location = external_data_info->GetRelPath(); - - external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) - : (tensor_proto_dir / location); - - ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); - const size_t external_data_length = external_data_info->GetLength(); - ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, - "TensorProto: ", tensor_proto.name(), - " external data size mismatch. Computed size: ", *&tensor_byte_size, - ", external_data.length: ", external_data_length); - - file_offset = external_data_info->GetOffset(); - - return Status::OK(); -} - // Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully. // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr // then uses the current directory instead. @@ -261,6 +230,37 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size) { + ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), + "Tensor does not have external data to read from."); + + ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), + "External data type cannot be UNDEFINED or STRING."); + + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + + const auto& location = external_data_info->GetRelPath(); + + external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? 
std::filesystem::path(location) + : (tensor_proto_dir / location); + + ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); + const size_t external_data_length = external_data_info->GetLength(); + ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, + "TensorProto: ", tensor_proto.name(), + " external data size mismatch. Computed size: ", *&tensor_byte_size, + ", external_data.length: ", external_data_length); + + file_offset = external_data_info->GetOffset(); + + return Status::OK(); +} + void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) { tensor_proto.set_raw_data(std::move(param)); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 227ba0706197e..262f7adaca1cb 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -23,6 +23,20 @@ namespace onnxruntime { namespace utils { +/** + * This function is used to get the external data info from the given tensor proto. + * @param tensor_proto given initializer tensor + * @param tensor_proto_dir directory of the tensor proto file + * @param external_file_path output external file path + * @param file_offset output tensor offset + * @param tensor_byte_size output tensor byte size + * @returns Status::OK() if the function is executed successfully + */ +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size); /** * This function is used to convert the endianess of Tensor data. * Mostly, will be used in big endian system to support the model file diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index 304fffa4ab7ca..96657d482d3a8 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -7,7 +7,7 @@ #include #include "core/common/common.h" -#include "core/platform/ort_mutex.h" +#include #include "core/framework/allocator.h" #include "core/framework/tuning_results.h" @@ -77,7 +77,7 @@ class TuningResultsManager { void Clear(); private: - mutable OrtMutex lock_; + mutable std::mutex lock_; std::unordered_map results_; }; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 9eed0249711f9..ff664c2c76703 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -57,7 +57,6 @@ void DestroyStrings(void* p_data, int64_t elements) { bool ProviderIsCpuBased(const std::string& provider_type) { return provider_type == onnxruntime::kCpuExecutionProvider || provider_type == onnxruntime::kDnnlExecutionProvider || - provider_type == onnxruntime::kTvmExecutionProvider || provider_type == onnxruntime::kVitisAIExecutionProvider || provider_type == onnxruntime::kOpenVINOExecutionProvider || provider_type == onnxruntime::kNnapiExecutionProvider || diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index c706c6fc5ff5f..f2a2a52f8334f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -787,14 +787,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(7, "beam_width", - "The beam width that is being used while decoding." 
+ "The beam width that is being used while decoding. " "If not provided, the beam width will be assumed to be 1.", "M", OpSchema::Optional) .Input(8, "cache_indirection", - "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies" - "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration", + "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies " + "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration", "M", OpSchema::Optional) .Output(0, @@ -871,7 +871,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(4, "attention_bias", - "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Input(5, @@ -902,15 +902,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(8, "beam_width", - "The beam width that is being used while decoding." + "The beam width that is being used while decoding. " "If not provided, the beam width will be assumed to be 1.", "M", OpSchema::Optional) .Input(9, "cache_indirection", - // This input is useful for CUDA EP only. - "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies" - "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration", + "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies " + "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration", "M", OpSchema::Optional) .Input(10, @@ -940,7 +939,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Output(3, "qk", - "normalized Q * K, of shape (batch_size, num_heads, 1, head_size). ", + "normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length). ", "V", OpSchema::Optional) .TypeConstraint("V", {"tensor(float)"}, "Constrain qk output types to float32 tensors.") diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 115db369d2af0..c7a0793c4748f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1067,11 +1067,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1, ONNX_MS_OPERATOR_SET_SCHEMA( UnfoldTensor, 1, OpSchema() - .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. " - "Step between two slices is given by step. " - "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in " - "the returned tensor will be (sizedim - size) / step + 1. " - "An additional dimension of size size is appended in the returned tensor.") + .SetDoc("Returns a tensor which contains all slices of size `size` from input tensor in the dimension `dim`. " + "Step between two slices is given by `step`. " + "If `sizedim` is the size of dimension `dim` for input tensor, the size of dimension `dim` in " + "the returned tensor will be `(sizedim - size) / step + 1`. 
" + "An additional dimension of size `size` is appended in the returned tensor.") .Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1)) .Attr("size", "specify the size", AttributeProto::INT) .Attr("step", "specify the step.", AttributeProto::INT, static_cast(1)) @@ -1122,7 +1122,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point" "(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return " - "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])" + "dynamic time warping of shape [2, x], where the path made by all points (output[0][t], output[1][t])" "have the lowest cost among all paths from (0, 0) to (M-1, N-1).") .Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F") .Output(0, "output", "Output tensor. shape is [2, x], where max(M, N) <= x < M + N", "I") @@ -3335,6 +3335,11 @@ void RegisterContribSchemas() { AttributeProto::STRING, OPTIONAL_VALUE) .Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE) + .Attr( + "max_size", + "max size in the context. Usage depend on the EP.", + AttributeProto::INT, + static_cast(0)) .AllowUncheckedAttributes() .Input( 0, diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 6f1f1c831d191..5a3cd86b04492 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -9,7 +9,7 @@ #include "core/graph/constants.h" #include "core/graph/contrib_ops/contrib_defs.h" #include "core/graph/contrib_ops/shape_inference_functions.h" -#include "onnx/onnx-ml.pb.h" // ? +#include "core/graph/onnx_protobuf.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build @@ -23,7 +23,7 @@ void convTransposeShapeInference(InferenceContext& ctx); void convPoolShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx); namespace defs::math::utils { - void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); +void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); } } // namespace ONNX_NAMESPACE @@ -822,10 +822,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } } - if (all_lengths_known) { - output_shape->mutable_dim(axis)->set_dim_value(total_length); - } - })); + if (all_lengths_known) { + output_shape->mutable_dim(axis)->set_dim_value(total_length); + } + })); ONNX_MS_OPERATOR_SET_SCHEMA(QLinearWhere, 1, OpSchema() .SetDoc("Return elements, either from X or Y, depending on condition.") @@ -955,7 +955,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::INT, static_cast(0)) .Attr("do_rotary", "Whether to use rotary position embedding. 
Default value is 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is " + .Attr("past_present_share_buffer", + "Corresponding past and present are same tensor, its shape is " "(2, batch_size, num_heads, max_sequence_length, head_size)", AttributeProto::INT, OPTIONAL_VALUE) .Attr("mask_filter_value", diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index e67884e3875d8..1bae63b510563 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -554,8 +554,8 @@ static Status SaveModel(Model& model, const T& file_path) { model_proto.SerializeToArray(buffer, buffer_size); EM_ASM(({ - const buffer = $0; - const buffer_size = $1; + const buffer = Number($0); + const buffer_size = Number($1); const file_path = UTF8ToString($2); const bytes = new Uint8Array(buffer_size); bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size)); @@ -570,9 +570,9 @@ static Status SaveModel(Model& model, const T& file_path) { window.open(url, '_blank'); } }), - reinterpret_cast(buffer), - static_cast(buffer_size), - reinterpret_cast(file_path.c_str())); + buffer, + buffer_size, + file_path.c_str()); free(buffer); return Status::OK(); diff --git a/onnxruntime/core/graph/schema_registry.cc b/onnxruntime/core/graph/schema_registry.cc index a7d94f4571d96..496825f00d452 100644 --- a/onnxruntime/core/graph/schema_registry.cc +++ b/onnxruntime/core/graph/schema_registry.cc @@ -10,7 +10,7 @@ common::Status OnnxRuntimeOpSchemaRegistry::SetBaselineAndOpsetVersionForDomain( const std::string& domain, int baseline_opset_version, int opset_version) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); auto it = domain_version_range_map_.find(domain); if (domain_version_range_map_.end() != it) { diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 28ae64c4d5b3e..207c058d899b4 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1435,6 +1435,29 @@ MLAS_FP16* Destination, size_t Count ); +/** + * @brief rotary embedding for one hidden state vector + * + * @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported. + * @param input: input tensor, of shape [dim] + * @param sin: sin tensor, of shape [dim/2] + * @param cos: cos tensor, of shape [dim/2] + * @param dim: dimension of rotary embedding + * @param interleaved: whether the real part and imaginary parts are interleaved + * @param output: output tensor, of shape [dim] + */ +template +void +MLASCALL +MlasRotaryEmbedOneRow( + const T* input, + const T* sin, + const T* cos, + size_t dim, + bool interleaved, + T* output +); + /** * @brief Whether current CPU supports FP16 acceleration. */ diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 232bf2261ef4c..9608644a22523 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -27,51 +27,50 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. 
*/ typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32, /*!< input fp32, accumulator fp32 */ - CompFp16, /*!< input fp16, accumulator fp16 */ - CompBf16, /*!< input bf16, accumulator fp32 */ - CompInt8, /*!< input int8, accumulator int32 */ - - // special values that should be the first and last actual values - - CompMostAccurate = CompUndef, - CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; + SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ + HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ + BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ + SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ + HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ +} MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. + * + * @tparam T data type of input A */ -struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) +template +struct MLAS_QNBIT_GEMM_DATA_PARAMS { + const T* A = nullptr; ///< address of A (float32/16 matrix) size_t lda = 0; ///< leading dimension of A const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix - MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; }; /** * @brief Batched GEMM: C = A * B + Bias - * A must be a float32 matrix + * A must be a float32/16 matrix * B must be a quantized and packed n-bit int matrix * - * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called. * - * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether - * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with - * MlasSQNBitGemmPackQuantBData(). + * Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasQNBitGemmPackQuantBData(). * - * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should * point to an intermediate workspace buffer. 
* + * @tparam T data type of input A * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -81,36 +80,37 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] Workspace Address of intermediate workspace buffer. - If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); /** - * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform. * * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -126,22 +126,22 @@ MlasIsSQNBitGemmAvailable( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** * @brief Gets the size in bytes of the packed quantized B data. - * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of - * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). - * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to - * MlasSQNBitGemmBatch(). + * If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch(). + * If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasQNBitGemmBatch(). 
* * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -150,12 +150,12 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -186,12 +186,12 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* QuantBScale, diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp similarity index 99% rename from onnxruntime/core/mlas/lib/fp16_neon_common.cpp rename to onnxruntime/core/mlas/lib/cast_kernel_neon.cpp index 29734c2277667..8a385c9c61751 100644 --- a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp +++ b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - fp16_neon_common.cpp + cast_kernel_neon.cpp Abstract: diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 30b66cdb2ea78..f4c49905ebbd7 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -64,6 +64,15 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +template +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadLaneFloat16x4(const _mlas_fp16_* Buffer, MLAS_FLOAT16X4 vec) { + return vreinterpret_f16_u16( + vld1_lane_u16(Buffer, vreinterpret_u16_f16(vec), lane) + ); +} + MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) @@ -95,6 +104,14 @@ MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) vst1_u16(Buffer, vreinterpret_u16_f16(Vector)); } +template +MLAS_FORCEINLINE +void +MlasStoreLaneFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) +{ + vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), lane); +} + MLAS_FORCEINLINE void MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len) diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000000000..69e37d2b916d1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -0,0 +1,898 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hqnbitgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_QNBIT_GEMM_COMPUTE_TYPE HQNBIT_CompFp16. 
+ +--*/ + +#include + +#include +#include +#include + +#include "fp16_common.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ +MLAS_FORCEINLINE void +Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, + uint8x8_t& v4, uint8x8_t& v5, uint8x8_t& v6, uint8x8_t& v7) +{ + // v0: | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // v1: | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // v2: | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // v3: | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // v4: | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // v5: | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // v6: | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // v7: | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + + uint8x8x2_t a0 = vtrn_u8(v0, v1); + uint8x8x2_t a1 = vtrn_u8(v2, v3); + uint8x8x2_t a2 = vtrn_u8(v4, v5); + uint8x8x2_t a3 = vtrn_u8(v6, v7); + + // a0[0]: | B00 B10 | B01 B11 | B40 B50 | B41 B51 | B80 B90 | B81 B91 | Bc0 Bd0 | Bc1 Bd1 | + // a0[1]: | B20 B30 | B21 B31 | B60 B70 | B61 B71 | Ba0 Bb0 | Ba1 Bb1 | Be0 Bf0 | Be1 Bf1 | + // a1[0]: | B02 B12 | B03 B13 | B42 B52 | B43 B53 | B82 B92 | B83 B93 | Bc2 Bd2 | Bc3 Bd3 | + // a1[1]: | B22 B32 | B23 B33 | B62 B72 | B63 B73 | Ba2 Bb2 | Ba3 Bb3 | Be2 Bf2 | Be3 Bf3 | + // a2[0]: | B04 B14 | B05 B15 | B44 B54 | B45 B55 | B84 B94 | B85 B95 | Bc4 Bd4 | Bc5 Bd5 | + // a2[1]: | B24 B34 | B25 B35 | B64 B74 | B65 B75 | Ba4 Bb4 | Ba5 Bb5 | Be4 Bf4 | Be5 Bf5 | + // a3[0]: | B06 B16 | B07 B17 | B46 B56 | B47 B57 | B86 B96 | B87 B97 | Bc6 Bd6 | Bc7 Bd7 | + // a3[1]: | B26 B36 | B27 B37 | B66 B76 | B67 B77 | Ba6 Bb6 | Ba7 Bb7 | Be6 Bf6 | Be7 Bf7 | + + uint16x4x2_t b0 = vtrn_u16(vreinterpret_u16_u8(a0.val[0]), vreinterpret_u16_u8(a1.val[0])); + uint16x4x2_t b1 = vtrn_u16(vreinterpret_u16_u8(a0.val[1]), vreinterpret_u16_u8(a1.val[1])); + uint16x4x2_t b2 = vtrn_u16(vreinterpret_u16_u8(a2.val[0]), vreinterpret_u16_u8(a3.val[0])); + uint16x4x2_t b3 = vtrn_u16(vreinterpret_u16_u8(a2.val[1]), vreinterpret_u16_u8(a3.val[1])); + + // b0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B80 B90 | B81 B91 | B82 B92 | B83 B93 | + // b0[1]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | + // b1[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | + // b1[1]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | + // b2[0]: | B04 B14 | B05 B15 | B06 B16 | B07 B17 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // b2[1]: | B44 B54 | B45 B55 | B46 B56 | B47 B57 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // b3[0]: | B24 B34 | B25 B35 | B26 B36 | B27 B37 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // b3[1]: | B64 B74 | B65 B75 | B66 B76 | B67 B77 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), vreinterpret_u32_u16(b2.val[0])); + uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), vreinterpret_u32_u16(b2.val[1])); + uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b1.val[0]), vreinterpret_u32_u16(b3.val[0])); + uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b1.val[1]), vreinterpret_u32_u16(b3.val[1])); + + // c0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 
| B06 B16 | B07 B17 | + // c0[1]: | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // c1[0]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // c1[1]: | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // c2[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // c2[1]: | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // c3[0]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // c3[1]: | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + v0 = vreinterpret_u8_u32(c0.val[0]); + v1 = vreinterpret_u8_u32(c2.val[0]); + v2 = vreinterpret_u8_u32(c1.val[0]); + v3 = vreinterpret_u8_u32(c3.val[0]); + v4 = vreinterpret_u8_u32(c0.val[1]); + v5 = vreinterpret_u8_u32(c2.val[1]); + v6 = vreinterpret_u8_u32(c1.val[1]); + v7 = vreinterpret_u8_u32(c3.val[1]); +} + +MLAS_FORCEINLINE void +Transpose4x8(float16x8_t& v0, float16x8_t& v1, float16x8_t& v2, float16x8_t& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE void +Transpose4x4(float16x4_t& v0, float16x4_t& v1, float16x4_t& v2, float16x4_t& v3) +{ + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); +} + +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); + constexpr size_t nbits = 4; + constexpr size_t k_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % k_blk_dim == 0); + + const size_t k_blk_num = MlasDivRoundup(K, k_blk_dim); + const size_t n_blk_num = MlasDivRoundup(N, n_blk_dim); + constexpr size_t k_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, k_blk_dim); + const size_t iterations = k_blk_num * n_blk_num; // one iteration per block + const size_t ld = MlasDivRoundup(K, BlkLen) * MlasQNBitBlkDataSizeInBytes(nbits, BlkLen); + + // + // For blocks 16_K * 8_N, transpose bytes in 8x8 blocks like this: + // src B_k_n: + // | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | 
Be0 Bf0 | + // | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + // => dst: + // | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + // + + // + // For blocks < 8_N: + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + MlasTrySimpleParallel( + ThreadPool, iterations, + [&](ptrdiff_t tid) { + const size_t n_blk = tid / k_blk_num; + const size_t k_blk = tid % k_blk_num; + size_t n = n_blk * n_blk_dim; + const size_t src_offset = n * ld + k_blk * k_blk_bytes; + + if (n + n_blk_dim <= N) { + const size_t dst_offset = n * ld + k_blk * k_blk_bytes * n_blk_dim; + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + dst_offset; + + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v1 = vld1_u8(src + ld); + uint8x8_t v2 = vld1_u8(src + 2*ld); + uint8x8_t v3 = vld1_u8(src + 3*ld); + uint8x8_t v4 = vld1_u8(src + 4*ld); + uint8x8_t v5 = vld1_u8(src + 5*ld); + uint8x8_t v6 = vld1_u8(src + 6*ld); + uint8x8_t v7 = vld1_u8(src + 7*ld); + + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + vst1_u8(dst, v0); + vst1_u8(dst + 8, v1); + vst1_u8(dst + 16, v2); + vst1_u8(dst + 24, v3); + vst1_u8(dst + 32, v4); + vst1_u8(dst + 40, v5); + vst1_u8(dst + 48, v6); + vst1_u8(dst + 56, v7); + } else { + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + src_offset; + + for (; n < N; ++n, src += ld, dst += ld) { + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v_even = vand_u8(v0, vdup_n_u8(0x0F)); + uint8x8_t v_odd = vshr_n_u8(v0, 4); + uint8x8x2_t v1 = vzip_u8(v_even, v_odd); + uint8x8_t v2 = vorr_u8(v1.val[0], vshl_n_u8(v1.val[1], 4)); + vst1_u8(dst, v2); + } + } + } + ); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t b01 = vld1_u8(src_ptr); + uint8x8_t b23 = vld1_u8(src_ptr + 8); + uint8x8_t b45 = vld1_u8(src_ptr + 16); + uint8x8_t b67 = vld1_u8(src_ptr + 24); + uint8x8_t b89 = vld1_u8(src_ptr + 32); + uint8x8_t bab = vld1_u8(src_ptr + 40); + 
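// Layout note for this 16-row by 8-column tile (added here for readability): after the packing
// done by HQ4BitGemmPackQuantBData_CompFp16 above, each 8-byte load (b01 ... bef) carries one
// pair of adjacent K rows for all 8 columns, with the low nibble of every byte holding the even
// row and the high nibble the odd row. The vand/vshr + vshll sequence below widens each nibble
// to an fp16 lane, and vfmaq_f16 then applies the dequantization q * scale - scale * zero_point.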
uint8x8_t bcd = vld1_u8(src_ptr + 48); + uint8x8_t bef = vld1_u8(src_ptr + 56); + + float16x8_t b0 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b01, low_mask), 0)); + float16x8_t b1 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b01, 4), 0)); + float16x8_t b2 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b23, low_mask), 0)); + float16x8_t b3 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b23, 4), 0)); + float16x8_t b4 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b45, low_mask), 0)); + float16x8_t b5 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b45, 4), 0)); + float16x8_t b6 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b67, low_mask), 0)); + float16x8_t b7 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b67, 4), 0)); + float16x8_t b8 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b89, low_mask), 0)); + float16x8_t b9 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b89, 4), 0)); + float16x8_t ba = vcvtq_f16_u16(vshll_n_u8(vand_u8(bab, low_mask), 0)); + float16x8_t bb = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bab, 4), 0)); + float16x8_t bc = vcvtq_f16_u16(vshll_n_u8(vand_u8(bcd, low_mask), 0)); + float16x8_t bd = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bcd, 4), 0)); + float16x8_t be = vcvtq_f16_u16(vshll_n_u8(vand_u8(bef, low_mask), 0)); + float16x8_t bf = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bef, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, b0, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, b1, scale); + float16x8_t c2 = vfmaq_f16(neg_scaled_zp, b2, scale); + float16x8_t c3 = vfmaq_f16(neg_scaled_zp, b3, scale); + float16x8_t c4 = vfmaq_f16(neg_scaled_zp, b4, scale); + float16x8_t c5 = vfmaq_f16(neg_scaled_zp, b5, scale); + float16x8_t c6 = vfmaq_f16(neg_scaled_zp, b6, scale); + float16x8_t c7 = vfmaq_f16(neg_scaled_zp, b7, scale); + float16x8_t c8 = vfmaq_f16(neg_scaled_zp, b8, scale); + float16x8_t c9 = vfmaq_f16(neg_scaled_zp, b9, scale); + float16x8_t ca = vfmaq_f16(neg_scaled_zp, ba, scale); + float16x8_t cb = vfmaq_f16(neg_scaled_zp, bb, scale); + float16x8_t cc = vfmaq_f16(neg_scaled_zp, bc, scale); + float16x8_t cd = vfmaq_f16(neg_scaled_zp, bd, scale); + float16x8_t ce = vfmaq_f16(neg_scaled_zp, be, scale); + float16x8_t cf = vfmaq_f16(neg_scaled_zp, bf, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); + MlasStoreFloat16x8(dst_ptr + 16, c2); + MlasStoreFloat16x8(dst_ptr + 24, c3); + MlasStoreFloat16x8(dst_ptr + 32, c4); + MlasStoreFloat16x8(dst_ptr + 40, c5); + MlasStoreFloat16x8(dst_ptr + 48, c6); + MlasStoreFloat16x8(dst_ptr + 56, c7); + MlasStoreFloat16x8(dst_ptr + 64, c8); + MlasStoreFloat16x8(dst_ptr + 72, c9); + MlasStoreFloat16x8(dst_ptr + 80, ca); + MlasStoreFloat16x8(dst_ptr + 88, cb); + MlasStoreFloat16x8(dst_ptr + 96, cc); + MlasStoreFloat16x8(dst_ptr + 104, cd); + MlasStoreFloat16x8(dst_ptr + 112, ce); + MlasStoreFloat16x8(dst_ptr + 120, cf); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 1 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t v0 = vld1_u8(src_ptr); + + float16x8_t f_low = vcvtq_f16_u16(vshll_n_u8(vand_u8(v0, low_mask), 0)); + float16x8_t f_high = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(v0, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, f_low, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, f_high, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); +} + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* 
QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +) { + MLAS_UNREFERENCED_PARAMETER(K); + constexpr size_t nbits = 4; + constexpr size_t kk_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % kk_blk_dim == 0); + + const size_t kk_blk_num = BlockCountK * BlkLen / kk_blk_dim; + constexpr size_t kk_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, kk_blk_dim); + const size_t kk_n_src_bytes = kk_blk_bytes * n_blk_dim; + const size_t kk_n_dst_size = kk_blk_dim * n_blk_dim; + const size_t ld_blk_src = kk_blk_num * kk_n_src_bytes; + const size_t ld_blk_dst = BlkLen * BlockCountK * n_blk_dim; + const size_t ld_blk_scale = BlockCountK * n_blk_dim; + const size_t ld_zp = (BlockCountK + 1) / 2; + const size_t ld_blk_zp = ld_zp * n_blk_dim; + const float16x8_t zp_mid_point_vec = MlasBroadcastFloat16x8(MLAS_FP16(8.0f).val); + const bool has_zp = QuantBZeroPoint != nullptr; + + size_t n = 0; + for (; n + n_blk_dim <= CountN; n += n_blk_dim) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + const std::uint8_t* src_ptr = reinterpret_cast(QuantBData); + auto* dst_ptr = reinterpret_cast<_mlas_fp16_*>(FpData); + + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + // prepare scales and zero_points for the block + _mlas_fp16_ scales[n_blk_dim]; + uint16_t zero_points[n_blk_dim]; + float16x8_t scale_vec; + float16x8_t neg_scaled_zp_vec; + + UnrolledLoop([&](int nn){ + scales[nn] = scales_ptr[nn * BlockCountK]; + }); + scale_vec = MlasLoadFloat16x8(scales); + + if (has_zp) { + UnrolledLoop([&](int nn){ + uint8_t zp = zero_points_ptr[nn * ld_zp]; + zp = (k_blk_i & 1) ? (zp >> 4) : (zp & 0x0F); + zero_points[nn] = static_cast(zp); + }); + uint16x8_t zp_u16_vec = vld1q_u16(zero_points); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<8, 16>(src_ptr, scale_vec, neg_scaled_zp_vec, dst_ptr); + + src_ptr += kk_n_src_bytes; + dst_ptr += kk_n_dst_size; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBData += ld_blk_src; + FpData += ld_blk_dst; + QuantBScale += ld_blk_scale; + QuantBZeroPoint = has_zp ? QuantBZeroPoint + ld_blk_zp : nullptr; + } + + // remaining N + for (; n < CountN; ++n) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + const auto scale = scales_ptr[0]; + float16x8_t scale_vec = MlasBroadcastFloat16x8(scale); + float16x8_t neg_scaled_zp_vec; + + if (has_zp) { + uint8_t zero_point = static_cast(zero_points_ptr[0]); + zero_point = (k_blk_i & 1) ? 
(zero_point >> 4) : (zero_point & 0x0F); + uint16x8_t zp_u16_vec = vdupq_n_u16(static_cast(zero_point)); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<1, 16>( + reinterpret_cast(QuantBData), scale_vec, neg_scaled_zp_vec, + reinterpret_cast<_mlas_fp16_*>(FpData) + ); + + QuantBData += kk_blk_bytes; + FpData += kk_blk_dim; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBScale += BlockCountK; + if (has_zp) { + QuantBZeroPoint += ld_zp; + } + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8), float16x8_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x8(Bias); + } else { + return MlasZeroFloat16x8(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 4), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x4(Bias); + } else { + return MlasZeroFloat16x4(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N == 2 || N == 1)), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + float16x4_t v = MlasZeroFloat16x4(); + + if (Bias) { + v = MlasLoadLaneFloat16x4<0>(Bias, v); + if constexpr (N == 2) { + v = MlasLoadLaneFloat16x4<1>(Bias + 1, v); + } + return v; + } else { + return v; + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 8), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x8_t a0 = MlasLoadFloat16x8(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + float16x8_t b4 = MlasLoadFloat16x8(B + 32); + float16x8_t b5 = MlasLoadFloat16x8(B + 40); + float16x8_t b6 = MlasLoadFloat16x8(B + 48); + float16x8_t b7 = MlasLoadFloat16x8(B + 56); + + // This version uses less instructions, but introduces dependency path between instructions. + // Must pair it with loop unrolling to alleviate dependency path penalty. 
+ float16x8_t c0 = vfmaq_laneq_f16(accumulator, b0, a0, 0); + c0 = vfmaq_laneq_f16(c0, b1, a0, 1); + c0 = vfmaq_laneq_f16(c0, b2, a0, 2); + c0 = vfmaq_laneq_f16(c0, b3, a0, 3); + c0 = vfmaq_laneq_f16(c0, b4, a0, 4); + c0 = vfmaq_laneq_f16(c0, b5, a0, 5); + c0 = vfmaq_laneq_f16(c0, b6, a0, 6); + c0 = vfmaq_laneq_f16(c0, b7, a0, 7); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 4), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0); + c0 = vfmaq_lane_f16(c0, b1, a0, 1); + c0 = vfmaq_lane_f16(c0, b2, a0, 2); + c0 = vfmaq_lane_f16(c0, b3, a0, 3); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && (K == 2 || K == 1)), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasZeroFloat16x4(); + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + if constexpr (K == 2) a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + float16x8_t b0 = MlasLoadFloat16x8(B), b1; + if constexpr (K == 2) b1 = MlasLoadFloat16x8(B + 8); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0), c01; + if constexpr (K == 2) c01 = vfmaq_lane_f16(c0, b1, a0, 1); + + if constexpr (K == 1) + return c0; + else + return c01; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && K == 8), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x8_t a0 = MlasLoadFloat16x8(A); + + float16x8_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x8(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x8(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x8(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x8(B + ldb * 3); + + float16x8_t c00, c01, c02, c03; + c00 = vmulq_f16(b0, a0); + if constexpr (N > 1) + c01 = vmulq_f16(b1, a0); + else + c01 = MlasZeroFloat16x8(); + if constexpr (N > 2) + c02 = vmulq_f16(b2, a0); + else + c02 = MlasZeroFloat16x8(); + if constexpr (N > 3) + c03 = vmulq_f16(b3, a0); + else + c03 = MlasZeroFloat16x8(); + + Transpose4x8(c00, c01, c02, c03); + + float16x8_t c_low_high = vaddq_f16(vaddq_f16(c00, c01), vaddq_f16(c02, c03)); + float16x4_t c_low = vget_low_f16(c_low_high); + float16x4_t c_high = vget_high_f16(c_low_high); + float16x4_t c = vadd_f16(c_low, c_high); + + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K == 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x4_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x4(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x4(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x4(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x4(B + ldb * 3); + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = 
vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K > 0 && K < 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasZeroFloat16x4(), b1, b2, b3; + if constexpr (N > 1) b1 = MlasZeroFloat16x4(); + if constexpr (N > 2) b2 = MlasZeroFloat16x4(); + if constexpr (N > 3) b3 = MlasZeroFloat16x4(); + + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + b0 = MlasLoadLaneFloat16x4<0>(B, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<0>(B + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<0>(B + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<0>(B + ldb * 3, b3); + + if constexpr (K >= 2) { + a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + b0 = MlasLoadLaneFloat16x4<1>(B + 1, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 3, b3); + } + + if constexpr (K >= 3) { + a0 = MlasLoadLaneFloat16x4<2>(A + 2, a0); + b0 = MlasLoadLaneFloat16x4<2>(B + 2, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 3, b3); + } + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +typename std::enable_if_t<((CountN >= 1 && CountN <= 16 && ((CountN - 1) & CountN) == 0) && (CountM == 1 || CountM == 2)), void> +HQ4BitGemmKernel_CompFp16_Kernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const _mlas_fp16_* Bias, + _mlas_fp16_* C, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + using RegisterType = typename std::conditional_t<(CountN < 8), float16x4_t, float16x8_t>; + + RegisterType accu00, accu01, accu10, accu11; + constexpr size_t b_step = CountN >= 8 ? 8 : 1; + constexpr size_t N = CountN == 16 ? 8 : CountN; + + if constexpr (CountM == 2) { + accu00 = accu10 = PrepareAccumulator(Bias); + } else { + accu00 = PrepareAccumulator(Bias); + } + if constexpr (CountN == 16) { + if constexpr (CountM == 2) { + accu01 = accu11 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } else { + accu01 = PrepareAccumulator(Bias ? 
Bias + 8 : nullptr); + } + } + + size_t k = 0; + for (; k + 8 <= K; k += 8, A += 8, B += b_step * 8) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if (K & 4) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 4, A += 4, B += b_step * 4; + } + + if (K & 2) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 2, A += 2, B += b_step * 2; + } + + if (k < K) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C, accu00); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + 8, accu01); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C, accu00); + } else { + MlasStoreLaneFloat16x4<0>(C, accu00); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + 1, accu00); + } + } + + if constexpr (CountM == 2) { + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C + ldc, accu10); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C + ldc, accu10); + } else { + MlasStoreLaneFloat16x4<0>(C + ldc, accu10); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + ldc + 1, accu10); + } + } + } +} + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + assert(CountM <= 2); + + // 2M_16N is the balance between loop unrolling and register spill. + // More unroll will trigger register spill. + // Less unroll will increase micro kernel dependency path penalty. + // TODO: dequant 16N as continuous segments. Current version dequants 8N. 
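// Tiling overview: CountN is peeled below into 16-, 8-, 4-, 2- and 1-column tiles, each
// re-dispatching to the templated HQ4BitGemmKernel_CompFp16_Kernel so that every width gets a
// fully unrolled, register-resident accumulator; CountM is 1 or 2 (asserted above), matching the
// StrideM = 2 row stepping used by HQ4BitGemm_CompFp16.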
+ const auto* a = reinterpret_cast(A); + const auto* b = reinterpret_cast(B); + const auto* bias = reinterpret_cast(Bias); + auto* c = reinterpret_cast<_mlas_fp16_*>(C); + + for (; CountN >= 16; CountN -= 16) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<16, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<16, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 16 * ldb, c += 16; + if (bias) bias += 16; + } + + if (CountN & 8) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<8, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<8, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 8 * ldb, c += 8; + if (bias) bias += 8; + } + + if (CountN & 4) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<4, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<4, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 4 * ldb, c += 4; + if (bias) bias += 4; + } + + if (CountN & 2) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<2, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<2, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 2 * ldb, c += 2; + if (bias) bias += 2; + } + + if (CountN & 1) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<1, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<1, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + } +} +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 13ea8d96c20e4..100d7d47751aa 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -358,6 +358,22 @@ size_t bool ZeroMode ); +#ifdef FORCE_GENERIC_ALGORITHMS +typedef +size_t +(MLASCALL MLAS_GEMM_FLOAT_KERNEL_GENERIC)( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha + ); +#endif + #else #if defined(__aarch64__) && defined(__linux__) @@ -733,6 +749,10 @@ extern "C" { #if defined(MLAS_TARGET_AMD64_IX86) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelSse; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx; +#ifdef FORCE_GENERIC_ALGORITHMS + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelZero; + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelAdd; +#endif #if defined(MLAS_TARGET_AMD64) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelFma3; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx512F; @@ -1017,17 +1037,24 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; // Float/quantized n-bit integer matrix/matrix multiply dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH; +struct MLAS_QNBIT_GEMM_DISPATCH; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; + +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; + +// +// Rotary embedding dispatch structure. 
+// +struct MLAS_ROPE_DISPATCH; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; // // Quantized depthwise convolution kernels. @@ -1184,10 +1211,12 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; - const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; + + const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 23d29fd02fa5a..ec572a4150292 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -286,7 +286,11 @@ Return Value: this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; #ifndef __APPLE__ +#ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; +#else // FORCE_GENERIC_ALGORITHMS + this->CastF16ToF32Kernel = nullptr; +#endif // FORCE_GENERIC_ALGORITHMS #endif // __APPLE__ this->NchwcBlockSize = 8; @@ -308,8 +312,11 @@ Return Value: // // Check if the processor supports SSE 4.1 instructions. // - +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x80000) != 0) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchSse41; } @@ -319,7 +326,11 @@ Return Value: // Check if the processor supports the AVX and OSXSAVE features. // +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x18000000) == 0x18000000) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS // // Check if the operating system supports saving SSE and AVX states. @@ -387,7 +398,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; @@ -417,7 +428,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) @@ -458,7 +469,7 @@ Return Value: this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; // // Check if the processor supports AVX512VNNI. 
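// Dispatch note on the FORCE_GENERIC_ALGORITHMS guards above: when that macro is defined, the
// CPUID tests for SSE4.1 and for AVX/OSXSAVE compile down to `if (false)`, so the SSE/AVX GEMM
// dispatch tables are never installed and MLAS keeps its portable kernels; the AVX2/AVX512
// QNBitGemmDispatch assignments shown above appear to sit inside that AVX-gated branch, so they
// are skipped as well.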
@@ -471,7 +482,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } @@ -531,6 +542,8 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; + this->RopeDispatch = &MlasRopeDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. @@ -560,9 +573,6 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; - - // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index 1ef5b5f7411f0..bcd878efa681b 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -867,7 +867,8 @@ MlasGemmQuantGetDispatch( { const MLAS_GEMM_QUANT_DISPATCH* GemmQuantDispatch = &MlasGemmQuantDispatchDefault; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) +#if !defined(FORCE_GENERIC_ALGORITHMS) +#if defined(MLAS_TARGET_AMD64_IX86) if (AIsSigned) { GemmQuantDispatch = BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch; @@ -895,7 +896,13 @@ MlasGemmQuantGetDispatch( if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchPOWER10) { GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; } +#elif defined(MLAS_TARGET_LARCH64) + if (!AIsSigned) { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; + } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) if (nullptr == GemmQuantDispatch) { std::stringstream ss; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp similarity index 62% rename from onnxruntime/core/mlas/lib/sqnbitgemm.cpp rename to onnxruntime/core/mlas/lib/qnbitgemm.cpp index a45494ef2e04f..f064a8e1d6a78 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -6,16 +6,16 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.cpp + qnbitgemm.cpp Abstract: This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + multiplication hardware agnostic entrypoint, MlasQNBitGemmBatch, as well as some SQNBitGemm-related query functions. --*/ -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" #include @@ -23,35 +23,40 @@ Module Name: namespace { -enum SQNBitGemmVariant { +enum QNBitGemmVariant { SQNBitGemmVariantInvalid = -1, // Valid variants SQNBitGemmVariant_BitWidth4_CompFp32 = 0, SQNBitGemmVariant_BitWidth4_CompInt8, + HQNBitGemmVariant_BitWidth4_CompFp16, + HQNBitGemmVariant_BitWidth4_CompInt8, // End of valid variants - // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. // Its value is used as an array size. 
SQNBitGemmVariantCount, }; -SQNBitGemmVariant -GetSQNBitGemmVariant( +QNBitGemmVariant +GetQNBitGemmVariant( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == CompFp32 || - ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + if (ComputeType == SQNBIT_CompFp32) { return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8) { + } else if (ComputeType == HQNBIT_CompFp16) { + return HQNBitGemmVariant_BitWidth4_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -61,23 +66,28 @@ GetSQNBitGemmVariant( } // namespace bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return false; } - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && - Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; + Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case HQNBitGemmVariant_BitWidth4_CompFp16: { + return Dispatch->HQ4BitGemmPackQuantBData != nullptr && + Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && + Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return @@ -94,80 +104,80 @@ namespace { size_t -SQNBitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); } return 0; } size_t -SQNBitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 1; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } return 1; } size_t -SQNBitGemmPerGemmWorkspaceStride( +QNBitGemmPerGemmWorkspaceStride( size_t M, size_t N, 
size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); - const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const auto Size = QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); return MlasDivRoundup(Size, Alignment) * Alignment; } } // namespace size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); if (PerGemmWorkspaceStride == 0) { return 0; } - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; @@ -175,21 +185,21 @@ MlasSQNBitGemmBatchWorkspaceSize( } size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { - return Dispatch->SQ4BitGemmPackQuantBDataSize( + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q4BitGemmPackQuantBDataSize( N, K, BlkLen, ComputeType ); } @@ -213,12 +223,12 @@ struct PerGemmQuantAWorkspace { }; void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSumWorkspace, const void* QuantBScale, @@ -227,15 +237,15 @@ MlasSQNBitGemmPackQuantBData( MLAS_THREADPOOL* ThreadPool ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return; } if (BlkBitWidth == 4) { - if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -248,6 +258,16 @@ MlasSQNBitGemmPackQuantBData( packed_quant_b, ThreadPool ); + } else if (ComputeType == HQNBIT_CompFp16 && Dispatch->HQ4BitGemmPackQuantBData != nullptr) { + Dispatch->HQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + 
static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. //assert(QuantBScale == nullptr); @@ -295,22 +315,11 @@ AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t } } -typedef void(SQNBitGemmFn)( - size_t BlkLen, - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* PerGemmWorkspace, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - void SQ4BitGemm_CompFp32( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -355,7 +364,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); @@ -393,7 +402,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32( BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); @@ -425,11 +434,84 @@ SQ4BitGemm_CompFp32( } } +void +HQ4BitGemm_CompFp16( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const size_t k_blk_num = MlasDivRoundup(K, BlkLen); + const size_t qldb = k_blk_num * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ldb = k_blk_num * BlkLen; + const size_t k_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blk_num); + + const MLAS_FP16* A = DataParams->A + RangeStartM * lda; + MLAS_FP16* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * qldb; + const MLAS_FP16* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blk_num; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_zp_bytes; + const MLAS_FP16* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias; + + // 32N is the sweet spot of cache utilization. It is machine dependent though. + constexpr size_t StrideM = 2; + constexpr size_t StrideN = 32; + + // TODO(fajin): move allocation up to the op. 
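// [Illustrative aside, not part of the patch] The buffer allocated just below holds
// StrideN dequantized fp16 columns of B, each padded to a whole number of quantization
// blocks along K. A minimal standalone sketch of that sizing, assuming 2-byte fp16
// elements; the names here are hypothetical and exist only for this example.
#include <cstddef>

// Mirrors MlasDivRoundup(K, BlkLen): number of quantization blocks along K.
constexpr std::size_t BlocksAlongK(std::size_t K, std::size_t BlkLen) {
    return (K + BlkLen - 1) / BlkLen;
}

// Bytes for a dequantized tile of StrideN columns, each BlocksAlongK(K, BlkLen) * BlkLen
// fp16 values long (i.e. K rounded up to a whole block).
constexpr std::size_t DequantBTileBytes(std::size_t K, std::size_t BlkLen, std::size_t StrideN) {
    const std::size_t PaddedK = BlocksAlongK(K, BlkLen) * BlkLen;
    return PaddedK * StrideN * 2;  // 2 bytes per fp16 element
}

// Example: K = 4096, BlkLen = 32, StrideN = 32 -> 4096 * 32 * 2 = 262144 bytes.
static_assert(DequantBTileBytes(4096, 32, 32) == 262144, "sizing illustration only");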
+ size_t bufsize = ldb * StrideN * sizeof(MLAS_FP16); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant_b, QuantBData, QuantBScale, QuantBZeroPoint, countN, K, k_blk_num + ); + + const MLAS_FP16* a = A; + MLAS_FP16* c = C; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + a, dequant_b, Bias, c, countM, countN, K, lda, ldb, ldc + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + m, RangeStartN + n, countM, countN, ldc + ); + } + + a += countM * lda; + c += countM * ldc; + } + + QuantBData += countN * qldb; + QuantBScale += countN * k_blk_num; + QuantBZeroPoint = QuantBZeroPoint ? QuantBZeroPoint + countN * k_zp_bytes : nullptr; + Bias = Bias ? Bias + countN : nullptr; + C += countN; + } +} + void SQ4BitGemm_CompInt8( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -500,10 +582,10 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + const auto RowsHandled = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias ); @@ -522,10 +604,10 @@ SQ4BitGemm_CompInt8( } } #ifdef MLAS_TARGET_AMD64_IX86 - else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + else if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( BlkLen, QuantA, QuantAScale, @@ -554,26 +636,29 @@ SQ4BitGemm_CompInt8( } } -typedef void(InitializeWorkspaceFn)( +template +void +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool ); +template <> void -InitializeWorkspace_CompInt8( +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool @@ -581,8 +666,8 @@ InitializeWorkspace_CompInt8( { MLAS_UNREFERENCED_PARAMETER(N); - const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + const auto QuantizeARow = 
GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); @@ -622,61 +707,153 @@ InitializeWorkspace_CompInt8( } } -struct Operations { - InitializeWorkspaceFn* InitializeWorkspace = nullptr; - SQNBitGemmFn* SQNBitGemm = nullptr; -}; +template <> +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +) { + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(Workspace); + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} + +template +using InitializeWorkspaceFn = std::function* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +)>; -constexpr auto OperationMap = []() { - std::array ops; +template +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant); - ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template +using QNBitGemmFn = std::function* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +)>; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; +template +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant); - return ops; -}(); +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: + return SQ4BitGemm_CompFp32; + case SQNBitGemmVariant_BitWidth4_CompInt8: + return SQ4BitGemm_CompInt8; + default: + return nullptr; + } +} + +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompFp16: + return HQ4BitGemm_CompFp16; + default: + return nullptr; + } +} } // namespace +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( const size_t M, const size_t N, const size_t K, const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // // Ensure `Workspace` has correct alignment. 
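// [Illustrative aside, not part of the patch] The expression below rounds the workspace
// pointer up to the next multiple of Alignment with the classic
// (addr + A - 1) & ~(A - 1) trick, which is only valid for power-of-two alignments.
// A self-contained sketch; AlignPointerUp is a hypothetical name used only here.
#include <cassert>
#include <cstdint>

inline void* AlignPointerUp(void* p, std::uintptr_t alignment) {
    assert(alignment != 0 && (alignment & (alignment - 1)) == 0 && "power-of-two alignment expected");
    const std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(p);
    return reinterpret_cast<void*>((addr + alignment - 1) & ~(alignment - 1));
}
// e.g. AlignPointerUp(reinterpret_cast<void*>(0x1003), 64) yields an address of 0x1040.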
// if (Workspace != nullptr) { - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); Workspace = reinterpret_cast( (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) ); } - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); } - const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const auto ComputeOperation = GetQNBitGemm(Variant); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -685,11 +862,11 @@ MlasSQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -756,11 +933,11 @@ MlasSQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); @@ -769,3 
+946,33 @@ MlasSQNBitGemmBatch( } }); } + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h similarity index 71% rename from onnxruntime/core/mlas/lib/sqnbitgemm.h rename to onnxruntime/core/mlas/lib/qnbitgemm.h index 2da336ca2f0ec..eb3d0b44ae3de 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.h + qnbitgemm.h Abstract: @@ -46,24 +46,25 @@ MlasAlignAddress(void* addr, const size_t alignment) return addr; } +template struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { - // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize + // TODO: duplicate code from Q4BitGemmPackQuantBDataSize constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); // _mm256_load_si256 requires alignment on a 32-byte boundary PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); - QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); } std::byte* PackedQuantBData; - float* PackedQuantBScale; - float* QuantBBlkSum; + T* PackedQuantBScale; + T* QuantBBlkSum; void* QuantBWorkspace_; size_t N_, BlockCountK_, BlkLen_; @@ -84,44 +85,45 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) // Kernel dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH { +struct MLAS_QNBIT_GEMM_DISPATCH { // // Quantized B data packing function prototypes. // - /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ - typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; - /** Packs quantized B data containing 4-bit integers. 
See MlasSQNBitGemmPackQuantBData(). */ - typedef void(SQ4BitGemmPackQuantBData_Fn)( + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ + typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool ); - SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ); @@ -141,15 +143,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; + Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -157,15 +159,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; + Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; // - // CompFp32 kernel function prototypes. + // SQNBIT_CompFp32 kernel function prototypes. // /** @@ -228,10 +230,41 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { size_t BlockStrideQuantB ); - Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; + Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. 
+ * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr; // - // CompInt8 kernel function prototypes. + // SQNBIT_CompInt8 kernel function prototypes. // /** @@ -337,4 +370,35 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply fp16 matrix A rows with fp16 matrix B columns. + * Results are written to fp16 matrix C. + * If bias is provided, the bias are added to the result. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param Bias the bias at the target column. Optional. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param K the number of columns of A matrix and rows of B matrix. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + */ + typedef void(HQ4BitGemmKernel_CompFp16_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc + ); + + HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp similarity index 74% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp rename to onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 3f32cc6c5312d..d05de64e68ec8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. 
Module Name: - sqnbitgemm_kernel_neon.cpp + qnbitgemm_kernel_neon.cpp Abstract: @@ -19,8 +19,8 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon @@ -34,11 +34,11 @@ namespace // size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType @@ -55,7 +55,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -69,7 +69,7 @@ SQ4BitGemmPackQuantBData( const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t Iterations = N * BlockCountK; // one iteration per block - const size_t SubBlkLen = (ComputeType == CompInt8) + const size_t SubBlkLen = (ComputeType == SQNBIT_CompInt8) ? ((BlkLen == 16) ? 16 : 32) : 16; @@ -126,18 +126,18 @@ SQ4BitGemmPackQuantBData( // size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); @@ -150,15 +150,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { @@ -175,20 +175,27 @@ SQ4BitGemmPerGemmWorkspaceAlignment( // Kernel dispatch structure definition. 
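// [Illustrative aside, not part of the patch] The dispatch tables defined below are
// built with an immediately-invoked lambda so the table itself can be const while
// still being filled in member by member, and optional kernels are simply left as
// nullptr for callers to probe. A minimal standalone sketch of that pattern; the
// ExampleDispatch type and all names in it are hypothetical.
#include <cstddef>

struct ExampleDispatch {
    using KernelFn = void (*)(const float* in, float* out, std::size_t n);
    KernelFn BaselineKernel = nullptr;
    KernelFn OptionalFp16Kernel = nullptr;  // stays nullptr when the CPU feature is absent
};

inline void ExampleBaselineCopy(const float* in, float* out, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) out[i] = in[i];
}

const ExampleDispatch g_ExampleDispatch = []() {
    ExampleDispatch d;
    d.BaselineKernel = ExampleBaselineCopy;
    // d.OptionalFp16Kernel would be assigned only behind a CPU-feature check.
    return d;
}();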
// -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - - d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) { + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + } d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; + d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; + d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 + return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h similarity index 69% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h rename to onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index ef9345d7ac484..ccadd24ac1991 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + qnbitgemm_kernel_neon.h Abstract: @@ -30,13 +30,13 @@ namespace sqnbitgemm_neon // // Function declarations for SQNBitGemm ARM NEON kernel entry points. -// Refer to the prototypes in sqnbitgemm.h for documentation. +// Refer to the prototypes in qnbitgemm.h for documentation. // These are declared here so they can be used to initialize the -// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// MLAS_QNBIT_GEMM_DISPATCH structure and also be implemented in separate // files. 
// -// CompFp32 declarations +// SQNBIT_CompFp32 declarations void SQ4BitGemmM1Kernel_CompFp32( @@ -53,7 +53,7 @@ SQ4BitGemmM1Kernel_CompFp32( ); void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, @@ -64,7 +64,48 @@ Q4BitBlkDequantBForSgemm_CompFp32( size_t BlockCountK ); -// CompInt8 declarations +// HQNBIT_CompFp16 declarations +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +); + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +); + +#endif // !(defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)) + +// SQNBIT_CompInt8 declarations void QuantizeARow_CompInt8( diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp new file mode 100644 index 0000000000000..1f8f7b240694c --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp @@ -0,0 +1,101 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding.cpp + +Abstract: + + This module implements rotary embedding kernels for fp32/16. + +--*/ + +#include "rotary_embedding.h" + +namespace { + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +) { + const size_t half_rotary_emb_dim = rotary_emb_dim / 2; + size_t cache_idx = 0; + bool sign = false; + size_t j = 0; + for (size_t i = 0; i < rotary_emb_dim; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_rotary_emb_dim; + sign = i & 1; + j = sign ? 
i - 1 : i + 1; // i - sign + } else { + cache_idx = i % half_rotary_emb_dim; + sign = (i >= half_rotary_emb_dim); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; + } + float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); + float input_data_j = static_cast(input_data[j]); + float sin_data_cache_idx = static_cast(sin_data[cache_idx]); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output_data[i] = static_cast(output_data_i); + } +} + +} // namespace + + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->SRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + return; + } + + dispatch->SRope(input, sin, cos, dim, interleaved, output); +} + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->HRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + return; + } + + dispatch->HRope(input, sin, cos, dim, interleaved, output); +} diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h new file mode 100644 index 0000000000000..352dddccf1025 --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding.h @@ -0,0 +1,46 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing rotary embedding. + +--*/ + +#pragma once + +#include "mlasi.h" + +struct MLAS_ROPE_DISPATCH { + // rotary embedding kernel for fp32 + typedef void(SRope_Fn)( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output + ); + + SRope_Fn* SRope = nullptr; + + // rotary embedding kernel for fp16 + typedef void(HRope_Fn)( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output + ); + + HRope_Fn* HRope = nullptr; +}; diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp new file mode 100644 index 0000000000000..e59a95cd9ee4e --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon.cpp + +Abstract: + + This module implements the rotary embedding kernels for ARM NEON. + +--*/ + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_neon.h" + +// +// Kernel dispatch structure definition. 
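// [Illustrative aside, not part of the patch] The NEON RoPE kernels registered below
// implement the same rotation as the scalar fallback above: in the non-interleaved
// layout, the first half of a row is the real part, the second half the imaginary
// part, and each pair is rotated by the angle encoded in sin/cos. A scalar float
// reference with hypothetical names, handy for sanity-checking a vectorized kernel.
#include <cstddef>

inline void RopeReferenceNonInterleaved(const float* input, const float* sin_vals,
                                        const float* cos_vals, std::size_t dim,
                                        float* output) {
    const std::size_t half = dim / 2;  // dim is assumed to be even
    for (std::size_t i = 0; i < half; ++i) {
        const float re = input[i];
        const float im = input[i + half];
        output[i]        = re * cos_vals[i] - im * sin_vals[i];
        output[i + half] = re * sin_vals[i] + im * cos_vals[i];
    }
}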
+// +const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() { + MLAS_ROPE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.HRope = rope_neon::RopeKernel_Fp16; + } +#endif + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h new file mode 100644 index 0000000000000..8153f65650f7d --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + rotary embedding on ARM cpu. + +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace rope_neon { + +// Rotary embedding kernel for fp16. Embed one hidden state vector. +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +); + +} // namespace rope_neon diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp new file mode 100644 index 0000000000000..3e2eb8fee0e6e --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -0,0 +1,253 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 rotary embedding kernels for ARM NEON. + +--*/ + +#include +#include + +#include "fp16_common.h" +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_neon.h" + +namespace rope_neon { + +namespace { + +template +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +); + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float16x8_t real = MlasLoadFloat16x8(input + i); + float16x8_t imag = MlasLoadFloat16x8(input + j); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x8(output + i, real_out); + MlasStoreFloat16x8(output + j, imag_out); + } + for (; i + 3 < half_dim; i += 4, j += 4) { + float16x4_t real = MlasLoadFloat16x4(input + i); + float16x4_t imag = MlasLoadFloat16x4(input + j); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x4(output + i, real_out); + MlasStoreFloat16x4(output + j, imag_out); + } + if (half_dim - i == 3) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = 
MlasLoadLaneFloat16x4<1>(input + i + 1, real); + real = MlasLoadLaneFloat16x4<2>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out); + } else if (half_dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = MlasLoadLaneFloat16x4<1>(input + i + 1, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + } else if (half_dim - i == 1) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + } +} + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float16x8_t x0 = MlasLoadFloat16x8(input + i); + float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); + float16x8_t real = vuzp1q_f16(x0, x1); + float16x8_t imag = vuzp2q_f16(x0, x1); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + float16x8_t y0 = vzip1q_f16(real_out, imag_out); + float16x8_t y1 = vzip2q_f16(real_out, imag_out); + 
MlasStoreFloat16x8(output + i, y0); + MlasStoreFloat16x8(output + i + 8, y1); + } + for (; i + 7 < dim; i += 8) { + float16x4_t x0 = MlasLoadFloat16x4(input + i); + float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); + float16x4_t real = vuzp1_f16(x0, x1); + float16x4_t imag = vuzp2_f16(x0, x1); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + float16x4_t y0 = vzip1_f16(real_out, imag_out); + float16x4_t y1 = vzip2_f16(real_out, imag_out); + MlasStoreFloat16x4(output + i, y0); + MlasStoreFloat16x4(output + i + 4, y1); + } + if (dim - i == 6) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); + imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + MlasStoreLaneFloat16x4<2>(output + i + 4, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out); + } else if (dim - i == 4) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + } else if (dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t 
real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + } +} + +} // namespace + +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + // real part and imaginary part must be paired + assert(dim % 2 == 0); + + const auto* input_impl = reinterpret_cast(input); + const auto* sin_impl = reinterpret_cast(sin); + const auto* cos_impl = reinterpret_cast(cos); + auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output); + + if (interleaved) { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } else { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } +} + +} // namespace rope_neon diff --git a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp b/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp index 62729256dac23..cbec5d89bbac7 100644 --- a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp +++ b/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp @@ -83,6 +83,8 @@ Return Value: #endif + int countb = 0; + do { float BElements00; @@ -116,6 +118,7 @@ Return Value: // const float* a = A; + const float* b = B; size_t k = CountK; while (k >= 2) { @@ -128,10 +131,10 @@ Return Value: Row1AElements1 = a[lda + 1]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -144,10 +147,10 @@ Return Value: Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - BElements00 = B[4]; - BElements01 = B[5]; - BElements02 = B[6]; - BElements03 = B[7]; + BElements00 = b[16]; + BElements01 = b[17]; + BElements02 = b[18]; + BElements03 = b[19]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements1; Row0Block01 = Row0Block01 + BElements01 * Row0AElements1; Row0Block02 = Row0Block02 + BElements02 * Row0AElements1; @@ -161,7 +164,7 @@ Return Value: } a += 2; - B += 8; + b += 32; k -= 2; } @@ -173,10 +176,10 @@ Return Value: Row1AElements0 = a[lda]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -188,8 +191,6 @@ Return Value: Row1Block02 = Row1Block02 + BElements02 * Row1AElements0; Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - - B += 4; } // @@ -295,9 +296,14 @@ Return Value: break; } + B += 4; C += 4; CountN -= 4; + countb = (countb + 1) % 4; + if (countb == 0) { + B += CountK * 16 - 16; + } } while (CountN > 0); return ProcessTwoRows ? 
2 : 1; diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 4d7a1ceb4eee7..f8b25fb42caf3 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { @@ -1158,6 +1158,7 @@ Return Value: if (M == 1 && TransA == CblasNoTrans && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { +#if !defined(FORCE_GENERIC_ALGORITHMS) #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; @@ -1181,6 +1182,7 @@ Return Value: } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) } @@ -1193,7 +1195,7 @@ Return Value: if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) && !defined(FORCE_GENERIC_ALGORITHMS) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index baaa4ba1a3b1f..81615da46aa2e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" @@ -1306,12 +1306,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -1319,9 +1319,9 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - // TODO: always use SubBlkLen = 64 in CompInt8 + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { SubBlkLen = 64; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -1330,18 +1330,18 @@ SQ4BitGemmPackQuantBDataAndBlkSum( // // Kernel dispatch structure definition. 
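// [Illustrative aside, not part of the patch] Every Q4 kernel wired into the dispatch
// tables below ultimately applies the same per-block rule: a 4-bit code q becomes
// scale * (q - zero_point). A scalar sketch over an already-unpacked block (the real
// kernels operate on packed nibbles in SIMD registers), with hypothetical names;
// treating the codes as unsigned with a zero point of 8 when none is stored is an
// assumption of this example.
#include <cstddef>
#include <cstdint>

inline void DequantizeQ4Block(const std::uint8_t* codes,  // BlkLen codes, one per byte, already unpacked
                              float scale,
                              std::uint8_t zero_point,    // e.g. 8 when no explicit zero point is stored (assumption)
                              std::size_t BlkLen,
                              float* out) {
    for (std::size_t i = 0; i < BlkLen; ++i) {
        out[i] = scale * static_cast<float>(static_cast<int>(codes[i]) - static_cast<int>(zero_point));
    }
}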
// -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; @@ -1349,18 +1349,18 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { return d; }(); -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 80d67806ea6e8..445ead329acf8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index af6f52090adcb..5dab8091ce760 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 174ebc580904c..d4b89bd9bad2d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template @@ -117,7 +117,7 @@ accumulate_blklen64_r1c1blk1_avx2( __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 13bd369a065bb..b4e25d4e4040a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx512_int8_blklen16.h" @@ -28,7 +28,7 @@ Module Name: #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. // #include "sqnbitgemm_kernel_avx_common_fp32.h" @@ -151,7 +151,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( } // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. // MLAS_FORCEINLINE @@ -332,12 +332,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -346,24 +346,24 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h index 7d9dc36854621..8f1ea6676b788 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" @@ -81,7 +81,7 @@ accumulate_blklen32_r2c1blk2_avx2( _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -143,7 +143,7 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. 
////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); @@ -184,7 +184,7 @@ accumulate_blklen32_r2c1blk1_avx2( const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - + const int8_t zp = get_zp(true, QuantBZeroPointPtr); const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); @@ -435,7 +435,7 @@ Q4Int8Gemm2x4BlkLen32Avx2( } } -template +template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const std::byte* QuantBData, @@ -877,7 +877,7 @@ MLAS_FORCEINLINE QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, C + multipleRows * ldc + multipleCols, remainingRows, - remainingCols, + remainingCols, BlockCountK, Bias ? Bias + multipleCols : nullptr, lda, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index 60a887345d0e0..d79554c34c108 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index bb14babd6c2b1..03064886caf24 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index e9df6b952bd27..3b1096ac05ba7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 2a65ac4af0c1d..72ce28d834199 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" static MLAS_FORCEINLINE __m256 diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6a5c01162c51b..a4468bb906bbc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include 
"sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" @@ -314,12 +314,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -328,7 +328,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -337,18 +337,18 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 177f5518bb891..b0367b7fb9a15 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" // @@ -7,16 +7,16 @@ // static size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); @@ -39,7 +39,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -304,7 +304,7 @@ PackQuantBDataAndBlkSum( const float* 
QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -326,18 +326,18 @@ PackQuantBDataAndBlkSum( // static size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch(ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); // QuantData + Scale + BlkSum @@ -351,15 +351,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } static size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h index 5cd380e591098..d15cfc782e125 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" template MLAS_FORCEINLINE diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 895ce6cd091c2..2e96082968866 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index 12ddc42506e98..31a499b8243af 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompFp32. --*/ @@ -21,8 +21,8 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon { @@ -31,7 +31,7 @@ namespace { // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. 
// MLAS_FORCEINLINE void @@ -608,7 +608,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_Impl( } // namespace void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index 0d62ea37b7e26..73beb06a3cfad 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8. --*/ @@ -21,15 +21,15 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon { // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 45c3963365e6b..941b884d0b9d2 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h index e9c3812bde899..ed78dfa67042d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index 267a82b72670c..935114c40d1a7 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "onnx/defs/shape_inference.h" #include "onnx/defs/tensor_proto_util.h" #include "core/framework/tensorprotoutils.h" @@ -767,7 +768,8 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No } // check where has X=-Infinity - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(where.InputDefs()[1]), -INFINITY, true)) { + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(where.InputDefs()[1]), + -std::numeric_limits::infinity(), true)) { DEBUG_LOG("where const not matched."); return false; } diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 1466de51d0b99..e755b4bfa6364 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, #if !defined(DISABLE_SPARSE_TENSORS) // Create execution frame for executing constant nodes. 
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - is_sparse_initializer_check); + is_sparse_initializer_check, logger); #else // Create execution frame for executing constant nodes. - OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + {node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; }, + logger); #endif std::vector fetch_mlvalue_idxs; diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index e8e395678436e..103e72072f713 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -823,6 +823,12 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l !graph_utils::IsSupportedProvider(layer_norm_node, GetCompatibleExecutionProviders())) { continue; } + + // The third input (beta) is optional in LayerNormalization-17 of onnx domain. Make sure 3 inputs are available. + if (!(layer_norm_node.InputDefs().size() == 3 && layer_norm_node.InputDefs()[2]->Exists())) { + continue; + } + // Find Attention after LayerNormalization const Node* p_attention = graph_utils::FirstChildByType(layer_norm_node, "Attention"); if (p_attention == nullptr) { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index f769d31092d19..ba2b87b5aa0ca 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -63,6 +63,7 @@ #ifdef MLAS_TARGET_AMD64_IX86 #include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h" #endif +#include "core/optimizer/qdq_transformer/bias_quantization.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" @@ -189,6 +190,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -243,6 +245,7 @@ InlinedVector> GenerateTransformers( if (!disable_quant_qdq) { transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique()); // EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input // condition that it ensures is important for the partitioning that happens after Level1 optimizers are run. 
@@ -402,7 +405,8 @@ InlinedVector> GenerateTransformers( } auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } @@ -435,6 +439,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -488,7 +493,8 @@ InlinedVector> GenerateTransformersForMinimalB #ifndef DISABLE_CONTRIB_OPS AllocatorPtr cpu_allocator = std::make_shared(); auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 67ebc22dab41d..b1665c7172549 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // going to a node that will need a Cast. // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. 
-static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, + const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), // does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node), @@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const KernelCreateInfo* kernel_create_info{}; const auto lookup_status = cpu_kernel_registry.TryFindKernel( kCpuExecutionProvider, node.OpType(), node.Domain(), - node.SinceVersion(), type_constraint_map, &kernel_create_info); + node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return true; } @@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: return false; } -static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) { // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 node.SetExecutionProviderType(""); } @@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { return dst_bit_length <= src_bit_length; } - if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || + (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { return true; } @@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_)); + ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger)); GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 7953cde6686c0..56f7d28cd5b77 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -177,7 +177,11 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid for (size_t i = 2; i < node->Inputs().size(); i++) { auto constant = api_graph->GetConstant(node->Inputs()[i]); if (constant != nullptr && constant->Data().size() > 0) { - input_perms.push_back(&input_perm); + // Starting from opset version 18, the 'scales' and 'sizes' can be any length up to the input rank. 
+ // However, our current implementation only supports the transposition of 4D tensors. + if (constant->NumElements() == 4) { + input_perms.push_back(&input_perm); + } } else { // TODO: Fix inconsistency. We should Transpose the non-const inputs so that the result of our changes // is consistent - all layout specific inputs are in NHWC format when we're done. diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index e944522c9c338..6b76dc626fba0 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons return false; } + // Checks the first input of MatMul has 2 dimensions. + // The test for the second input is done in method Apply as it accesses the constant. + if (node.InputDefs()[0] == nullptr) { + // This should never happen but just in case. + return false; + } + auto shape_a = node.InputDefs()[0]->Shape(); + if (shape_a == nullptr) { + // We cannot shape the rank. It is better to avoid fusing. + return false; + } + if (shape_a->dim_size() != 2) { + // Gemm only supports 2D tensors. + return false; + } + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. const auto& output_defs = batch_norm_node->OutputDefs(); if (output_defs.size() > 1) { @@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& bias_tensor->dims_size() != 1 || mean_tensor->dims_size() != 1 || var_tensor->dims_size() != 1 || + matmul_b_tensor->dims_size() != 2 || scale_tensor->dims(0) != matmul_b_tensor->dims(1) || bias_tensor->dims(0) != matmul_b_tensor->dims(1) || mean_tensor->dims(0) != matmul_b_tensor->dims(1) || diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index 4fee1a6ce224e..b619efb2f751e 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -49,6 +49,49 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { return data_type == actual_data_type; } +// Return total mnumber of Elements. +static uint64_t NumElements(const TensorShapeProto* tensor_shape) { + if (nullptr == tensor_shape || tensor_shape->dim_size() < 1) { + return 0; + } + uint64_t num_elements = 1; + + for (int i = 0; i < tensor_shape->dim_size(); i++) { + num_elements *= tensor_shape->dim(i).dim_value(); + } + return num_elements; +} + +bool CheckMatMulLargeTensors(const Node& matmulinteger_node, const Node& cast_node) { + const auto a_def = matmulinteger_node.InputDefs()[0]; + const auto b_def = matmulinteger_node.InputDefs()[1]; + const int a_dim_size = a_def->Shape()->dim_size(); + const int b_dim_size = b_def->Shape()->dim_size(); + uint64_t a_num_elements = NumElements(a_def->Shape()); + uint64_t b_num_elements = NumElements(b_def->Shape()); + + if (a_dim_size != b_dim_size) { + bool a_is_broadcasted = a_dim_size < b_dim_size; + if (a_is_broadcasted) { + for (int i = 0; i < b_dim_size - a_dim_size; i++) { + a_num_elements *= b_def->Shape()->dim(i).dim_value(); + } + } else { + for (int i = 0; i < a_dim_size - b_dim_size; i++) { + b_num_elements *= a_def->Shape()->dim(i).dim_value(); + } + } + } + + int output_data_type = HasElementDataType(*cast_node.OutputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) ? 
2 : 4; + uint64_t total_bytes = (a_num_elements + b_num_elements) * output_data_type; + + if (total_bytes > UINT32_MAX) { + return true; + } + return false; +} + /** MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat: @@ -114,6 +157,17 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g continue; } + const Node* p_dynamicquantize_node = graph_utils::FirstParentByType(*p_matmulinteger_node, "DynamicQuantizeLinear"); + + // Check MatMulInteger Nodes' input is coming from DynamicQuantizeLinear + // For larger tensors DynamicQuantizeLinear -> MatMulInteger is used to be resource efficient + // And we have better MatMulInteger Metacommand coverage in DML + if (is_dml_ep && p_dynamicquantize_node) { + if (CheckMatMulLargeTensors(matmulinteger_node, cast_node)) { + continue; + } + } + // Find bias node Node* p_add_node = nullptr; if (optimizer_utils::CheckOutputEdges(graph, mul_node, 1)) { diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index ee79fa620374e..cd654991c92d5 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -44,7 +44,9 @@ NhwcConvLookup( return &(iter->second); } -NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept +NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, + std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) { if (!cpu_kernel_registry) { // This is a CPU op nodes optimizer, not useful if cpu EP is not available. @@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_, - qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info); + qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_, - qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info); + qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, - nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info); + nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, - nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info); + nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -140,7 +142,7 @@ 
NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, - nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, - nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( diff --git a/onnxruntime/core/optimizer/nhwc_transformer.h b/onnxruntime/core/optimizer/nhwc_transformer.h index 000732060b889..c65f851fdab9d 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.h +++ b/onnxruntime/core/optimizer/nhwc_transformer.h @@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed. class NhwcTransformer : public GraphTransformer { private: public: - explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept; + explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept; /** * @brief Usually called right after constructor, it shows whether diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index ed7d5feb2beb3..b2e8e491c361c 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& /* model_path */, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const { std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; 
- return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out); + return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out); } static Status TryCreateKernel(const Node& node, @@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node, FuncManager& funcs_mgr, const DataTransferManager& data_transfer_mgr, const ConfigOptions& config_options, + const logging::Logger& logger, /*out*/ std::unique_ptr& op_kernel) { const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; const KernelCreateInfo* kernel_create_info = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, + ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger, &kernel_create_info)); static const AllocatorMap dummy_allocators; @@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); FuncManager func; auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_, - ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, + ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_, op_kernel); // Kernel found in the CPU kernel registry diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index b0f7f461661b5..24a23312feba9 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame { const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); ~Info() = default; @@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame { std::unique_ptr node_index_info_; const IExecutionProvider& execution_provider_; const std::function& is_sparse_initializer_func_; + const logging::Logger& logger_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info); }; diff --git a/onnxruntime/core/optimizer/pre_shape_node_elimination.cc b/onnxruntime/core/optimizer/pre_shape_node_elimination.cc index 23980c9c10e6b..8f50ef7c09c95 100644 --- a/onnxruntime/core/optimizer/pre_shape_node_elimination.cc +++ b/onnxruntime/core/optimizer/pre_shape_node_elimination.cc @@ -48,7 +48,7 @@ bool PreShapeNodeElimination::SatisfyCondition(const Graph& graph, const Node& n for (const Node* next_node : output_nodes) { // Check if the next node is not of type "Shape" - if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Shape", {13, 15, 19}, kOnnxDomain)) { + if (!next_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Shape", {13, 15, 19}, kOnnxDomain)) { return false; } } diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc 
b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 18e462c04dff3..5538aa54801cc 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion( return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end(); } -static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { +static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) { constexpr size_t w_idx = 1; constexpr size_t w_zp_idx = 9; constexpr size_t r_idx = 2; @@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) || !graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) || r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " cannot locate recurrence tensor of const int8 type," << " int8 overflow might impact precision !"; return false; @@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) || !graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) || r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " unable to locate recurrence tensor or its zero point value," << " int8 overflow might impact precision !"; return false; @@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int if (graph_utils::IsSupportedOptypeVersionAndDomain( op_node, "DynamicQuantizeLSTM", {1}, kMSDomain)) { // This one has two set of quantized arguments - modified |= TryConvertDynamicQuantizeLSTM(op_node, graph); + modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger); continue; // go on to next operator node } diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc new file mode 100644 index 0000000000000..9e9665e14ede4 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/qdq_transformer/bias_quantization.h" + +#include "core/common/common.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/optimizer/utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" + +namespace onnxruntime { + +Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + const GraphViewer graph_viewer{graph}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + auto* node_ptr = graph.GetNode(node_idx); + if (!node_ptr) { + continue; + } + + Node& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + const auto& input_defs = node.InputDefs(); + + // It's Conv/Gemm node with an initializer bias. 
+ if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() || + !graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) { + continue; + } + + auto bias_shape = input_defs[2]->Shape(); + if (!bias_shape || bias_shape->dim_size() != 1) { + continue; + } + int64_t bias_size = bias_shape->dim(0).dim_value(); + + // input_0 and input_1 are outputs of DequantizeLinear nodes. + const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name()); + const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name()); + if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName || + parent_node_1->OpType() != QDQ::DQOpName) { + continue; + } + + Node& dq_0 = *graph.GetNode(parent_node_0->Index()); + Node& dq_1 = *graph.GetNode(parent_node_1->Index()); + + // Currently we require input_0 is per-tensor scale. + if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) { + continue; + } + + // For input_1, it's either per-tensor scale or per-channel scale on specific axis (0 for Conv and 1 for Gemm). + bool is_per_tensor_scale = true; + if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) { + is_per_tensor_scale = false; + auto weight_scale_shape = dq_1.InputDefs()[1]->Shape(); + if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() || + weight_scale_shape->dim(0).dim_value() != bias_size) { + continue; + } + + const auto& dq_attrs = dq_1.GetAttributes(); + if (dq_attrs.find("block_size") != dq_attrs.end()) { + continue; + } + + int64_t axis = 1; + if (dq_attrs.find("axis") != dq_attrs.end()) { + axis = dq_attrs.at("axis").i(); + } + + int64_t expected_axis = 0; + if (node.OpType() == "Gemm") { + int64_t transB = 0; + if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) { + transB = attr->second.i(); + } + expected_axis = transB == 0 ? 1 : 0; + } + + if (axis != expected_axis) { + continue; + } + } + + // Bias is quantized to int32. + ONNX_NAMESPACE::TypeProto int32_type_proto; + int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + auto scale_type = dq_1.InputDefs()[1]->TypeAsProto(); // Maybe per-tensor (scalar) or per-channel (1D) scale. + ONNX_NAMESPACE::TypeProto bias_dq_type; + bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type()); + bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size); + + // scale = input_scale_0 * input_scale_1. + NodeArg& scale_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type); + Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node", + {dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr, + node.Domain()); + + // fp_bias / scale. + NodeArg& bias_div_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type); + Node& div_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node", + {node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain()); + graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1); + + // Round(fp_bias / scale). 
+ NodeArg& bias_div_round_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type); + Node& round_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node", + {&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain()); + graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0); + + // Cast(round(fp_bias / scale)) to int32. + NodeArg& bias_int32_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto); + Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node", + {&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain()); + cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32)); + graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0); + + // Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0. + NodeArg& bias_dq_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type); + Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node", + {&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain()); + if (!is_per_tensor_scale) { + dq_node.AddAttribute("axis", static_cast(0)); + } + + graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0); + graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1); + node.MutableInputDefs()[2] = &bias_dq_node_arg; + graph.AddEdge(dq_node.Index(), node.Index(), 0, 2); + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h new file mode 100644 index 0000000000000..0297def260fd9 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** + * @class BiasQuantization + * + * Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias + * with scale = scale_input_0 * scale_input_1 and zero_point = 0. + * + * Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed + * by a DequantizeLinear node. 
+ */ +class BiasQuantization : public GraphTransformer { + public: + BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 71d4ddd38913b..2f98711771f1b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -68,10 +68,14 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { // And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative // scale will change the ordering of the elements between quantized & de-quantized values. std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider}; + + // We don't drop the resample QDQ ops here for DML because we don't know yet whether it is allowed to be executed in DML. + // This will be done within DML during a graph pass if allowed, but otherwise we need to keep the dequantize op alive. + std::vector cpu_ep = {kCpuExecutionProvider}; std::unique_ptr selector_no_16bit = std::make_unique(false, false, true, - providers); + cpu_ep); qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name, {{"Resize", {}}}, std::move(selector_no_16bit), @@ -143,7 +147,7 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) - std::vector providers = {kCpuExecutionProvider}; + std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider}; std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"AveragePool", {}}, @@ -232,7 +236,7 @@ void ConvQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_ #if !defined(ORT_MINIMAL_BUILD) // TODO: Enable 16-bit types in selector when QLinearConv supports 16-bit. 
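The BiasQuantization transformer introduced above builds a Mul -> Div -> Round -> Cast(int32) -> DequantizeLinear subgraph in front of Conv/Gemm, which numerically amounts to bias_q = round(bias / (input_scale * weight_scale)) with a zero point of 0. The standalone sketch below restates that arithmetic outside the onnxruntime graph APIs; the function and variable names are illustrative and not part of the patch.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the Mul -> Div -> Round -> Cast(int32) chain that BiasQuantization
// inserts: scale = input_scale * weight_scale, zero point implicitly 0.
// weight_scales holds one value for per-tensor quantization or one value per
// output channel for per-channel quantization.
std::vector<int32_t> QuantizeBias(const std::vector<float>& bias,
                                  float input_scale,
                                  const std::vector<float>& weight_scales) {
  std::vector<int32_t> quantized(bias.size());
  for (size_t i = 0; i < bias.size(); ++i) {
    const float weight_scale = weight_scales.size() == 1 ? weight_scales[0] : weight_scales[i];
    const float bias_scale = input_scale * weight_scale;                    // Mul node
    quantized[i] = static_cast<int32_t>(std::round(bias[i] / bias_scale));  // Div -> Round -> Cast
  }
  return quantized;
}

int main() {
  const std::vector<float> bias = {0.25f, -1.5f, 3.0f};
  const float input_scale = 0.05f;
  const std::vector<float> weight_scales = {0.02f};  // per-tensor weight scale

  for (int32_t q : QuantizeBias(bias, input_scale, weight_scales)) {
    std::cout << q << '\n';  // prints 250, -1500, 3000
  }
  return 0;
}

The DequantizeLinear node the transformer appends then feeds bias_q * bias_scale back to the Conv/Gemm bias input, so up to rounding the model's numerics are unchanged.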
- std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider}; + std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider}; std::unique_ptr selector = std::make_unique(is_int8_allowed, false, false, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index d2240b5d50194..81305f7effa16 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -291,7 +291,8 @@ SelectorManager::SelectorManager() { InitializeSelectorsMap(); } -std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const { +std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer, + const logging::Logger& logger) const { std::vector qdq_selections; for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { const auto* node = graph_viewer.GetNode(index); @@ -313,7 +314,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) { - LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType(); + LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType(); continue; } } @@ -329,7 +330,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap } std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer) { +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) { std::vector> node_unit_holder; std::unordered_map node_unit_map; @@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) { // Get QDQ NodeUnits first QDQ::SelectorManager selector_mgr; - const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger); for (const auto& qdq_selection : qdq_selections) { auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index f388206551172..ccc1844e3e985 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -15,7 +15,9 @@ #endif namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; class Node; @@ -65,7 +67,7 @@ class SelectorManager { // Methods that finds and returns a vector of QDQ::NodeGroup in a given graph // Can be used in QDQ support in different EPs - std::vector GetQDQSelections(const GraphViewer& graph_viewer) const; + std::vector GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const; private: Selectors qdq_selectors_; @@ -88,7 +90,7 @@ class SelectorManager { // We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer // library whereas it should be able to be used by an EP with no dependency on optimizers. 
std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index f1e94dd4fe9e4..8c0136c495403 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -17,13 +17,22 @@ class TransformerMemcpyImpl { TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) : graph_(graph), provider_(provider) {} - bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); + bool ModifyGraph(const KernelRegistryManager& schema_registries, + const logging::Logger& logger, + int& copy_node_counter); private: - void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); - void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); + void ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); + void BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger); void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); - bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); + bool ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl); @@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // find defs that require copy for (auto& node : graph_.Nodes()) { // as we process the defs, collect all the initializers consumed at the current graph level - ProcessDefs(node, kernel_registries, initializers_consumed); + ProcessDefs(node, kernel_registries, initializers_consumed, logger); } // for initializers shared by different providers, create dups - if (ProcessInitializers(kernel_registries, initializers_consumed)) + if (ProcessInitializers(kernel_registries, initializers_consumed, logger)) modified = true; for (auto arg : graph_.GetInputs()) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_input_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_output_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : graph_.GetInputs()) // For inputs we need to create a copy node only when the input is connected to both provider @@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } -void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, - InitializedTensorSet& initializers_consumed) { +void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& 
logger) { auto node_provider_type = node.GetExecutionProviderType(); if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || @@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci)); bool is_implicit_input = false; auto process_inputs = @@ -256,13 +267,6 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg } } else if (node_provider_type != kCudaExecutionProvider && node_provider_type != kTensorrtExecutionProvider && node_provider_type != kRocmExecutionProvider && node_provider_type != kMIGraphXExecutionProvider) { - // TODO: copy between devices? i.e. multiple GPUs - if (node_provider_type != onnxruntime::kCpuExecutionProvider && - node_provider_type != onnxruntime::kVitisAIExecutionProvider && - !node_provider_type.empty()) { - ORT_THROW("Execution type '", node_provider_type, "' doesn't support memcpy "); - } - for (const auto* arg : node.InputDefs()) { if (arg->Exists()) non_provider_input_defs_.insert(arg); @@ -285,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg } // for non_provider defs, collect the nodes that expect it is provider tensor as input/output. -void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) { +void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger) { for (auto& it : graph_.Nodes()) { if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue; auto input_it = @@ -303,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci)); if (arg_input_index != -1) { if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } @@ -358,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co // We duplicate any initializer that is used by both provider nodes and non-provider nodes // to ensure that provider nodes and non-provider nodes don't share initializers, as they // need to stay in different memory locations. 
-bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) { +bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger) { std::map replacements; for (const auto& pair : initializers_consumed) { const auto& name = pair.first; @@ -390,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker auto dup_replacements = replacements; const KernelCreateInfo* kci = nullptr; - auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci); + auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); if (kci == nullptr) continue; if (kci->kernel_def == nullptr) continue; diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 470838d36ec1c..10cb6eb97bdd6 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -1653,14 +1654,14 @@ static bool HandleSplit(HandlerArgs& args) { constexpr HandlerInfo split_handler = {&FirstInput, &HandleSplit}; -static bool HandleConcat(HandlerArgs& args) { +bool HandleConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } constexpr HandlerInfo concat_handler = {&AllInputs, &HandleConcat}; // Handles Softmax, Hardmax, and LogSoftmax -static bool HandleSoftHardMax(HandlerArgs& args) { +bool HandleSoftHardMax(HandlerArgs& args) { if (args.ctx.opset >= 13) { return HandleSimpleNodeWithAxis(args, /*default_axis*/ -1); } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 0095ead75f0c8..f65bd6aa82fbb 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -71,6 +71,9 @@ bool HandleSimpleNodeBroadcast(HandlerArgs& args); // Transposes all inputs and all outputs. Updates axis attribute. bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_axis = std::nullopt); +bool HandleConcat(HandlerArgs& args); +bool HandleSoftHardMax(HandlerArgs& args); + // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. 
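// Illustrative aside, not part of the patch: exporting HandleConcat/HandleSoftHardMax from the
// header lets the extended (com.microsoft.*) handler table reuse them instead of keeping private
// static duplicates, as the ort_transpose_optimization.cc hunk below shows. A standalone sketch
// of that "one handler, several op names" registration pattern, with hypothetical names:
#include <iostream>
#include <string>
#include <unordered_map>

struct HandlerArgs { int axis = 0; };
using Handler = bool (*)(HandlerArgs&);

bool HandleAxisOp(HandlerArgs&) { return true; }  // shared implementation for axis-aware ops

const std::unordered_map<std::string, Handler>& HandlerMap() {
  static const std::unordered_map<std::string, Handler> map = {
      {"Concat", &HandleAxisOp},
      {"com.microsoft.QLinearConcat", &HandleAxisOp},  // reuse rather than a private copy
  };
  return map;
}

int main() {
  HandlerArgs args;
  std::cout << std::boolalpha << HandlerMap().at("com.microsoft.QLinearConcat")(args) << "\n";
}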
bool HandleReduceOps(HandlerArgs& args); bool HandleResize([[maybe_unused]] HandlerArgs& args); diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index 8eaac3d34c3af..824ab20a84668 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -34,10 +34,6 @@ static bool EPAwareHandleResize(HandlerArgs& args) { constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; -static bool HandleQLinearConcat(HandlerArgs& args) { - return HandleSimpleNodeWithAxis(args); -} - std::vector QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) { (void)ctx; std::vector indices; @@ -48,11 +44,7 @@ std::vector QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) { return indices; } -constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleQLinearConcat}; - -static bool HandleQLinearBinaryOp(HandlerArgs& args) { - return HandleSimpleNodeBroadcast(args); -} +constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleConcat}; std::vector QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) { // Inputs are: [A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point], @@ -60,7 +52,7 @@ std::vector QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) { return {0, 3}; } -constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleQLinearBinaryOp}; +constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleSimpleNodeBroadcast}; static bool HandleQLinearPoolOp(HandlerArgs& args) { // Swap between channel first/last variants. Only works for applicable values of perm. @@ -129,6 +121,7 @@ constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; +constexpr HandlerInfo soft_hard_max_handler = {&FirstInput, &HandleSoftHardMax}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, &HandleContribQuantizeDequantizeLinear}; @@ -148,6 +141,7 @@ const HandlerMap& OrtExtendedHandlers() { {"com.microsoft.QLinearMul", q_linear_binary_op_handler}, {"com.microsoft.QLinearReduceMean", reduce_op_handler}, {"com.microsoft.QLinearSigmoid", node_1_inp_handler}, + {"com.microsoft.QLinearSoftmax", soft_hard_max_handler}, }; return map; diff --git a/onnxruntime/core/platform/posix/ort_mutex.cc b/onnxruntime/core/platform/posix/ort_mutex.cc deleted file mode 100644 index e124ce168085f..0000000000000 --- a/onnxruntime/core/platform/posix/ort_mutex.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
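// Illustrative aside, not part of the patch: the nsync-based OrtCondVar implementation deleted
// just below (timed_wait_impl / wait) corresponds functionally to what std::condition_variable
// already provides once OrtMutex is replaced with std::mutex elsewhere in this change.
// Standalone sketch of the equivalent deadline wait:
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>

int main() {
  std::mutex m;
  std::condition_variable cv;
  bool ready = false;

  std::unique_lock<std::mutex> lk(m);
  const auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(50);
  // Blocks until notified with 'ready' set or the deadline passes; spurious wakeups are handled
  // by the predicate, mirroring what the deleted code built on nsync_cv_wait_with_deadline.
  const bool signalled = cv.wait_until(lk, deadline, [&] { return ready; });
  std::cout << (signalled ? "signalled" : "timed out") << "\n";  // prints "timed out" here
}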
- -#include "core/common/common.h" -#include "core/platform/ort_mutex.h" -#include -#include -#include - -namespace onnxruntime { -void OrtCondVar::timed_wait_impl(std::unique_lock& lk, - std::chrono::time_point tp) { - using namespace std::chrono; -#ifndef NDEBUG - if (!lk.owns_lock()) - ORT_THROW("condition_variable::timed wait: mutex not locked"); -#endif - nanoseconds d = tp.time_since_epoch(); - timespec abs_deadline; - seconds s = duration_cast(d); - using ts_sec = decltype(abs_deadline.tv_sec); - constexpr ts_sec ts_sec_max = std::numeric_limits::max(); - if (s.count() < ts_sec_max) { - abs_deadline.tv_sec = static_cast(s.count()); - abs_deadline.tv_nsec = static_cast((d - s).count()); - } else { - abs_deadline.tv_sec = ts_sec_max; - abs_deadline.tv_nsec = 999999999; - } - nsync::nsync_cv_wait_with_deadline(&native_cv_object, lk.mutex()->native_handle(), abs_deadline, nullptr); -} - -void OrtCondVar::wait(std::unique_lock& lk) { -#ifndef NDEBUG - if (!lk.owns_lock()) { - ORT_THROW("OrtCondVar wait failed: mutex not locked"); - } -#endif - nsync::nsync_cv_wait(&native_cv_object, lk.mutex()->native_handle()); -} - -} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc index bf3b53afbd7d3..7464ab4c57d01 100644 --- a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -1,7 +1,8 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "hardware_core_enumerator.h" +#include "core/platform/windows/env.h" #include #include #include @@ -83,6 +84,38 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. 
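// Illustrative aside, not part of the patch: the platform check added just below reads CPUID
// leaf 0 (the "GenuineIntel" vendor string spread across EBX, EDX and ECX) and leaf 1, whose
// EAX packs the stepping in bits [3:0] with model/family/extended-model/extended-family above
// it; shifting EAX right by 4 drops the stepping so the remaining signature can be compared
// against the MTL/ARL-H/ARL-U identifiers. Standalone sketch with a made-up sample EAX value:
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t sample_eax = 0x000A06A3;       // hypothetical leaf-1 EAX (stepping = 3)
  const uint32_t signature  = sample_eax >> 4;  // strip the stepping nibble
  std::printf("signature=0x%X stepping=%u\n", signature, sample_eax & 0xFu);
  std::printf("matches 0xA06A (the MTL id in the list below): %s\n",
              signature == 0xA06A ? "yes" : "no");
}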
auto cores = GetCoreInfo(); +#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) + const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + bool isIntelSpecifiedPlatform = false; + const int kVendorID_IntelSpecifiedPlatformIDs[3] = { + // ExtendedModel, ExtendedFamily, Family Code, and Model Number + 0xa06a, // MTL + 0xc065, // ARL-H + 0xb065 // ARL-U + }; + + int regs_leaf0[4]; + int regs_leaf1[4]; + __cpuid(regs_leaf0, 0); + __cpuid(regs_leaf1, 0x1); + + auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]); + + for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) { + if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) { + isIntelSpecifiedPlatform = true; + } + } + + if (isIntel) { + if (isIntelSpecifiedPlatform) { + // We want to exclude cores without an LLC + return cores.LLCCores; + } else { + return cores.PhysicalCores; + } + } +#endif return cores.LLCCores; } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index 889bc6fcf86df..950ac247a2046 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -65,12 +65,12 @@ EtwRegistrationManager& EtwRegistrationManager::Instance() { } bool EtwRegistrationManager::IsEnabled() const { - std::lock_guard lock(provider_change_mutex_); + std::lock_guard lock(provider_change_mutex_); return is_enabled_; } UCHAR EtwRegistrationManager::Level() const { - std::lock_guard lock(provider_change_mutex_); + std::lock_guard lock(provider_change_mutex_); return level_; } @@ -94,7 +94,7 @@ Severity EtwRegistrationManager::MapLevelToSeverity() { } ULONGLONG EtwRegistrationManager::Keyword() const { - std::lock_guard lock(provider_change_mutex_); + std::lock_guard lock(provider_change_mutex_); return keyword_; } @@ -103,12 +103,12 @@ HRESULT EtwRegistrationManager::Status() const { } void EtwRegistrationManager::RegisterInternalCallback(const EtwInternalCallback& callback) { - std::lock_guard lock(callbacks_mutex_); + std::lock_guard lock(callbacks_mutex_); callbacks_.push_back(&callback); } void EtwRegistrationManager::UnregisterInternalCallback(const EtwInternalCallback& callback) { - std::lock_guard lock(callbacks_mutex_); + std::lock_guard lock(callbacks_mutex_); auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), [&callback](const EtwInternalCallback* ptr) { return ptr == &callback; @@ -126,7 +126,7 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( _In_opt_ PVOID CallbackContext) { auto& manager = EtwRegistrationManager::Instance(); { - std::lock_guard lock(manager.provider_change_mutex_); + std::lock_guard lock(manager.provider_change_mutex_); manager.is_enabled_ = (IsEnabled != 0); manager.level_ = Level; manager.keyword_ = MatchAnyKeyword; @@ -135,11 +135,11 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( } EtwRegistrationManager::~EtwRegistrationManager() { - std::lock_guard lock(callbacks_mutex_); + std::lock_guard lock(callbacks_mutex_); callbacks_.clear(); if (initialization_status_ == InitializationStatus::Initialized || initialization_status_ == InitializationStatus::Initializing) { - std::lock_guard init_lock(init_mutex_); + std::lock_guard init_lock(init_mutex_); assert(initialization_status_ != InitializationStatus::Initializing); if (initialization_status_ == InitializationStatus::Initialized) { 
::TraceLoggingUnregister(etw_provider_handle); @@ -153,13 +153,16 @@ EtwRegistrationManager::EtwRegistrationManager() { void EtwRegistrationManager::LazyInitialize() { if (initialization_status_ == InitializationStatus::NotInitialized) { - std::lock_guard lock(init_mutex_); + std::lock_guard lock(init_mutex_); if (initialization_status_ == InitializationStatus::NotInitialized) { // Double-check locking pattern initialization_status_ = InitializationStatus::Initializing; etw_status_ = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr); if (FAILED(etw_status_)) { + // Registration can fail when running under Low Integrity process, and should be non-fatal initialization_status_ = InitializationStatus::Failed; - ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status_)); + // Injection of ETW logger can happen very early if ETW provider was already listening. + // Don't use LOGS_DEFAULT here or can get "Attempt to use DefaultLogger but none has been registered" + std::cerr << "Error in ETW registration: " << std::to_string(etw_status_) << std::endl; } initialization_status_ = InitializationStatus::Initialized; } @@ -174,9 +177,11 @@ void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, return; } - std::lock_guard lock(callbacks_mutex_); + std::lock_guard lock(callbacks_mutex_); for (const auto& callback : callbacks_) { - (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + if (callback != nullptr) { + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index d6c9ea27b2955..2a798a28f13de 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -24,7 +24,7 @@ #include "core/common/logging/capture.h" #include "core/common/logging/isink.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { namespace logging { @@ -98,9 +98,9 @@ class EtwRegistrationManager { _In_opt_ PVOID CallbackContext); std::vector callbacks_; - OrtMutex callbacks_mutex_; - mutable OrtMutex provider_change_mutex_; - OrtMutex init_mutex_; + std::mutex callbacks_mutex_; + mutable std::mutex provider_change_mutex_; + std::mutex init_mutex_; InitializationStatus initialization_status_ = InitializationStatus::NotInitialized; bool is_enabled_; UCHAR level_; diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index 3401507ae911f..cc23d70c0f11f 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -30,7 +30,6 @@ class CaptureStackTrace { // Get the stack trace. Currently only enabled for a DEBUG build as we require the DbgHelp library. 
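// Illustrative aside, not part of the patch: the etw_sink.cc hunk above keeps its explicit
// status enum plus init_mutex_ for lazy registration and now records a registration failure
// instead of throwing. For reference, the standard-library way to get run-exactly-once
// initialization with the same "record the outcome, don't throw" behaviour is std::call_once;
// standalone sketch with hypothetical names:
#include <iostream>
#include <mutex>

std::once_flag g_register_once;
long g_etw_status = 0;  // stand-in for the recorded HRESULT

void LazyRegister() {
  std::call_once(g_register_once, [] {
    // pretend registration failed; record it rather than throwing so callers keep working
    g_etw_status = -1;
    std::cerr << "Error in ETW registration: " << g_etw_status << std::endl;
  });
}

int main() {
  LazyRegister();
  LazyRegister();  // no-op: the callable ran exactly once
  std::cout << "status=" << g_etw_status << "\n";
}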
std::vector GetStackTrace() { #ifndef NDEBUG -// TVM need to run with shared CRT, so won't work with debug helper now #if (defined __cpp_lib_stacktrace) && !(defined _OPSCHEMA_LIB_) && !(defined _GAMING_XBOX) && !(defined ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) return detail::CaptureStackTrace().Trace(); #else diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 86067d377205b..47789af9d5a47 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/platform/windows/telemetry.h" -#include "core/platform/ort_mutex.h" +#include #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -57,18 +57,18 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim #pragma warning(pop) #endif -OrtMutex WindowsTelemetry::mutex_; -OrtMutex WindowsTelemetry::provider_change_mutex_; +std::mutex WindowsTelemetry::mutex_; +std::mutex WindowsTelemetry::provider_change_mutex_; uint32_t WindowsTelemetry::global_register_count_ = 0; bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; std::vector WindowsTelemetry::callbacks_; -OrtMutex WindowsTelemetry::callbacks_mutex_; +std::mutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); if (global_register_count_ == 0) { // TraceLoggingRegister is fancy in that you can only register once GLOBALLY for the whole process HRESULT hr = TraceLoggingRegisterEx(telemetry_provider_handle, ORT_TL_EtwEnableCallback, nullptr); @@ -79,7 +79,7 @@ WindowsTelemetry::WindowsTelemetry() { } WindowsTelemetry::~WindowsTelemetry() { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); if (global_register_count_ > 0) { global_register_count_ -= 1; if (global_register_count_ == 0) { @@ -87,22 +87,22 @@ WindowsTelemetry::~WindowsTelemetry() { } } - std::lock_guard lock_callbacks(callbacks_mutex_); + std::lock_guard lock_callbacks(callbacks_mutex_); callbacks_.clear(); } bool WindowsTelemetry::IsEnabled() const { - std::lock_guard lock(provider_change_mutex_); + std::lock_guard lock(provider_change_mutex_); return enabled_; } UCHAR WindowsTelemetry::Level() const { - std::lock_guard lock(provider_change_mutex_); + std::lock_guard lock(provider_change_mutex_); return level_; } UINT64 WindowsTelemetry::Keyword() const { - std::lock_guard lock(provider_change_mutex_); + std::lock_guard lock(provider_change_mutex_); return keyword_; } @@ -111,12 +111,12 @@ UINT64 WindowsTelemetry::Keyword() const { // } void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { - std::lock_guard lock_callbacks(callbacks_mutex_); + std::lock_guard lock_callbacks(callbacks_mutex_); callbacks_.push_back(&callback); } void WindowsTelemetry::UnregisterInternalCallback(const EtwInternalCallback& callback) { - std::lock_guard lock_callbacks(callbacks_mutex_); + std::lock_guard lock_callbacks(callbacks_mutex_); auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), [&callback](const EtwInternalCallback* ptr) { return ptr == &callback; @@ -132,7 +132,7 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - std::lock_guard 
lock(provider_change_mutex_); + std::lock_guard lock(provider_change_mutex_); enabled_ = (IsEnabled != 0); level_ = Level; keyword_ = MatchAnyKeyword; @@ -143,7 +143,7 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { - std::lock_guard lock_callbacks(callbacks_mutex_); + std::lock_guard lock_callbacks(callbacks_mutex_); for (const auto& callback : callbacks_) { (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index ed80f13e633ac..b23a60a44b5f0 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -8,7 +8,7 @@ #include "core/platform/telemetry.h" #include #include -#include "core/platform/ort_mutex.h" +#include #include "core/platform/windows/TraceLoggingConfig.h" namespace onnxruntime { @@ -69,14 +69,14 @@ class WindowsTelemetry : public Telemetry { static void UnregisterInternalCallback(const EtwInternalCallback& callback); private: - static OrtMutex mutex_; + static std::mutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; static std::vector callbacks_; - static OrtMutex callbacks_mutex_; - static OrtMutex provider_change_mutex_; + static std::mutex callbacks_mutex_; + static std::mutex provider_change_mutex_; static UCHAR level_; static ULONGLONG keyword_; diff --git a/onnxruntime/core/providers/cann/cann_allocator.h b/onnxruntime/core/providers/cann/cann_allocator.h index 15fa7b177904a..1022374b51d9f 100644 --- a/onnxruntime/core/providers/cann/cann_allocator.h +++ b/onnxruntime/core/providers/cann/cann_allocator.h @@ -6,7 +6,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 9a242919665bb..f954baf3eabae 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -28,7 +28,7 @@ using onnxruntime::common::Status; namespace onnxruntime { // Models can only be parsed and built serially in the same process -OrtMutex g_mutex; +std::mutex g_mutex; class Memcpy final : public OpKernel { public: @@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node); if (cann_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() - << " node name: " << node.Name(); + LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() + << " node name: " << node.Name(); continue; } candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger()); for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) continue; @@ -1389,7 +1389,7 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse if (modelIDs_.find(filename) != 
modelIDs_.end()) { modelID = modelIDs_[filename]; } else { - std::lock_guard lock(g_mutex); + std::lock_guard lock(g_mutex); if (cann::FileExist(filename_with_suffix)) { CANN_RETURN_IF_ERROR(aclmdlLoadFromFile(filename_with_suffix.c_str(), &modelID)); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index d83bd88d6958f..7debfa72778fd 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -12,7 +12,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cann/cann_execution_provider_info.h" #include "core/providers/cann/cann_inc.h" #include "core/providers/cann/cann_utils.h" diff --git a/onnxruntime/core/providers/cann/cann_kernel.h b/onnxruntime/core/providers/cann/cann_kernel.h index 90180144202a7..5effbb4f56043 100644 --- a/onnxruntime/core/providers/cann/cann_kernel.h +++ b/onnxruntime/core/providers/cann/cann_kernel.h @@ -4,7 +4,7 @@ #pragma once -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cann/cann_inc.h" #include "core/providers/cann/cann_call.h" #include "core/providers/cann/cann_execution_provider.h" diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index e1f148fa93e23..38ac629331749 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -24,11 +24,12 @@ namespace coreml { OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, int32_t coreml_version, - uint32_t coreml_flags) { + bool only_allow_static_input_shapes, + bool create_mlprogram) { return OpBuilderInputParams{graph_viewer, coreml_version, - (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0, - (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0}; + only_allow_static_input_shapes, + create_mlprogram}; } const IOpBuilder* GetOpBuilder(const Node& node) { @@ -133,13 +134,13 @@ bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& grap return true; } -bool HasNeuralEngine(const logging::Logger& logger) { +bool HasNeuralEngine() { bool has_neural_engine = false; #ifdef __APPLE__ struct utsname system_info; uname(&system_info); - LOGS(logger, VERBOSE) << "Current Apple hardware info: " << system_info.machine; + LOGS_DEFAULT(VERBOSE) << "Current Apple hardware info: " << system_info.machine; #if TARGET_OS_IPHONE // utsname.machine has device identifier. For example, identifier for iPhone Xs is "iPhone11,2". @@ -163,7 +164,7 @@ bool HasNeuralEngine(const logging::Logger& logger) { #else // In this case, we are running the EP on non-apple platform, which means we are running the model // conversion with CoreML EP enabled, for this we always assume the target system has Neural Engine - LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. " + LOGS_DEFAULT(INFO) << "HasNeuralEngine running on non-Apple hardware. " "Returning true to enable model conversion and local testing of CoreML EP implementation. 
" "No CoreML model will be compiled or run."; has_neural_engine = true; diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index 0acaa0dd8a4a3..ae7f3bdbc31a9 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -25,7 +25,8 @@ namespace coreml { OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, int32_t coreml_version, - uint32_t coreml_flags); + bool only_allow_static_input_shapes, + bool create_mlprogram); const IOpBuilder* GetOpBuilder(const Node& node); @@ -45,7 +46,7 @@ bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& grap // CoreML is more efficient running using Apple Neural Engine // This is to detect if the current system has Apple Neural Engine -bool HasNeuralEngine(const logging::Logger& logger); +bool HasNeuralEngine(); } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 5389eb5ab7e95..4481a5172966b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -40,6 +40,25 @@ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, con } namespace { + +template +void HandlePReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger, + std::vector& alpha_values) { + // add slope initializer as alpha weight + const auto& slope_tensor = *model_builder.GetConstantInitializer(node.InputDefs()[1]->Name()); + Initializer unpacked_tensor(slope_tensor); + const auto alpha_v = unpacked_tensor.DataAsSpan(); + + if (alpha_v.size() == 1) { + // expand to number of channels + std::vector x_shape; + GetShape(*node.InputDefs()[0], x_shape, logger); + alpha_values.resize(x_shape[x_shape.size() - 3], alpha_v[0]); + } else { + alpha_values.assign(alpha_v.begin(), alpha_v.end()); + } +} + Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger, COREML_SPEC::ActivationPReLU& prelu) { @@ -84,6 +103,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation std::string_view coreml_op_type; bool add_alpha = false; + bool add_gelu_mode = false; if (op_type == "Sigmoid") { coreml_op_type = "sigmoid"; } else if (op_type == "Tanh") { @@ -93,6 +113,12 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } else if (op_type == "LeakyRelu") { coreml_op_type = "leaky_relu"; add_alpha = true; + } else if (op_type == "Gelu") { + coreml_op_type = "gelu"; + add_gelu_mode = true; + } else if (op_type == "PRelu") { + coreml_op_type = "prelu"; + add_alpha = true; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -102,16 +128,39 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); if (add_alpha) { - NodeAttrHelper helper(node); - const auto alpha = helper.Get("alpha", 0.01f); - auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - if (input_dtype == 
ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + + if ("PRelu" == op_type) { + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector alpha_values; + HandlePReluWeight(model_builder, node, logger, alpha_values); + AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values)); + } else { + std::vector alpha_values; + HandlePReluWeight(model_builder, node, logger, alpha_values); + AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values)); + } } else { - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + NodeAttrHelper helper(node); + const auto alpha = helper.Get("alpha", 0.01f); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } else { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + } } } + if (add_gelu_mode) { + NodeAttrHelper helper(node); + std::string approximate = helper.Get("approximate", std::string("none")); + if (approximate == "tanh") { + approximate = "TANH_APPROXIMATION"; + } else if (approximate == "none") { + approximate = "EXACT"; + } + AddOperationInput(*op, "mode", model_builder.AddScalarConstant(op->type(), "mode", std::string(approximate))); + } AddOperationOutput(*op, *node.OutputDefs()[0]); @@ -213,17 +262,11 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp const logging::Logger& logger) const { const auto& op_type = node.OpType(); -#if defined(COREML_ENABLE_MLPROGRAM) - if (input_params.create_mlprogram) { - if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable - return false; - } - } else -#endif // (COREML_ENABLE_MLPROGRAM) - { - if (op_type == "PRelu") { - return IsPReluOpSupported(node, input_params, logger); - } + if (op_type == "Gelu" && !input_params.create_mlprogram) { + return false; + } + if (op_type == "PRelu") { + return IsPReluOpSupported(node, input_params, logger); } return true; @@ -245,6 +288,7 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration "Relu", "PRelu", "LeakyRelu", + "Gelu", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index bc8b2d1a3505d..6169090a36014 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -15,6 +16,9 @@ class ArgMaxOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + public: + bool SupportsMLProgram() const override { return true; } }; Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -24,41 +28,60 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& 
graph_viewer = model_builder.GetGraphViewer(); NodeAttrHelper helper(node); - const auto axis = helper.Get("axis", 0); - const auto keepdims = helper.Get("keepdims", 1); + const int64_t axis = helper.Get("axis", 0); + const int64_t keepdims = helper.Get("keepdims", 1); const bool removedim = keepdims != 1; - auto* coreml_argmax = layer->mutable_argmax(); - coreml_argmax->set_axis(axis); - coreml_argmax->set_removedim(removedim); - - // There are two cases here: - // 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input - // (We still have this special case here because CoreML model does not have Cast) - // 2. Otherwise, we add Argmax layer normally - if (node.GetOutputEdgesCount() == 1) { - auto it = node.OutputEdgesBegin(); - const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); - // If Argmax's successive node is a Cast from int64 to int32 output - // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) - // so we omit the check here - if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { - // Skip the cast's input/argmax's output - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); - return Status::OK(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.reduction + + std::unique_ptr op = model_builder.CreateOperation(node, "reduce_argmax"); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims))); + + int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; + // the output of ArgMax must be int32 + AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + auto* coreml_argmax = layer->mutable_argmax(); + coreml_argmax->set_axis(axis); + coreml_argmax->set_removedim(removedim); + + // There are two cases here: + // 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input + // (We still have this special case here because CoreML model does not have Cast) + // 2. 
Otherwise, we add Argmax layer normally + if (node.GetOutputEdgesCount() == 1) { + auto it = node.OutputEdgesBegin(); + const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); + // If Argmax's successive node is a Cast from int64 to int32 output + // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) + // so we omit the check here + if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { + // Skip the cast's input/argmax's output + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + return Status::OK(); + } } - } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } -bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // Attribute `select_last_index` of ArgMax op is not supported NodeAttrHelper helper(node); @@ -68,6 +91,12 @@ bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + return true; + } +#endif + // If there are multiple downstream nodes and cast (toint32) is one of them // not supported, exit here // Otherwise, for general multiple downstream nodes, supported diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index f185a80de3cbf..2817f34bc64f2 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -13,15 +13,6 @@ using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { -// Once all ops are supportted FP16, we can remove it. Before that, we keep a set of ops to -// filter suppported ones. 
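// Illustrative aside, not part of the patch: the Float16Ops allow-list removed below used to
// gate fp16 inputs per op type; after this change any ML Program op builder is expected to
// accept fp16 (or reject it in its own checks), so the gate collapses to "is this an ML
// Program?". Standalone sketch of the before/after predicate, with an abbreviated list:
#include <iostream>
#include <set>
#include <string>

bool Fp16SupportedBefore(bool create_mlprogram, const std::string& op_type) {
  static const std::set<std::string> float16_ops = {"Add", "Conv", "Gemm"};  // abbreviated
  return create_mlprogram && float16_ops.count(op_type) > 0;
}

bool Fp16SupportedAfter(bool create_mlprogram, const std::string& /*op_type*/) {
  return create_mlprogram;  // per-op restrictions now live in each builder's own checks
}

int main() {
  std::cout << std::boolalpha
            << Fp16SupportedBefore(true, "Gelu") << ' '   // false: not on the old list
            << Fp16SupportedAfter(true, "Gelu") << '\n';  // true
}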
-static std::set Float16Ops = { - "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", - "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", - "Clip", "DepthToSpace", "Resize", "Slice", "Conv", - "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", - "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; - namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, @@ -65,20 +56,27 @@ bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& } if (!HasSupportedOpSet(node, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] does not support this opset"; return false; } if (!HasSupportedInputs(node, input_params, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] has unsupported inputs"; return false; } // We do not support external initializers for now const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); if (HasExternalInitializer(initializers, node, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] has external initializers"; return false; } - return IsOpSupportedImpl(node, input_params, logger); + if (!IsOpSupportedImpl(node, input_params, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] is not supported by the impl"; + return false; + } + return true; } bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, @@ -115,13 +113,10 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, return true; } -// only support MLProgram for FP16 -#if defined(COREML_ENABLE_MLPROGRAM) - if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && - Float16Ops.count(node.OpType())) { + // only MLProgram support FP16 + if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { return true; } -#endif LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index 8da58f659acf1..442194cb31cbc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -10,6 +10,10 @@ #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" +#ifdef __APPLE__ +#include +#endif + namespace onnxruntime { namespace coreml { @@ -24,6 +28,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { // BatchNormalization opset 6- has unsupported attributes int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; } + + public: + bool SupportsMLProgram() const override { return true; } }; void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -50,21 +57,46 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu const auto eps = helper.Get("epsilon", 1e-5f); const auto channels = scale_tensor.dims()[0]; - auto* coreml_batch_norm = layer->mutable_batchnorm(); - coreml_batch_norm->set_channels(channels); - coreml_batch_norm->set_epsilon(eps); - coreml_batch_norm->set_computemeanvar(false); - coreml_batch_norm->set_instancenormalization(false); - - 
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var - - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - - model_builder.AddLayer(std::move(layer)); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm + + std::unique_ptr op = model_builder.CreateOperation(node, "batch_norm"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + AddOperationInput(*op, "mean", model_builder.AddConstant(op->type(), input_defs[3]->Name() + "mean", mean_tensor)); + AddOperationInput(*op, "variance", model_builder.AddConstant(op->type(), input_defs[4]->Name() + "variance", var_tensor)); + AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name(), scale_tensor)); + AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name(), bias_tensor)); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps)); + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + auto* coreml_batch_norm = layer->mutable_batchnorm(); + coreml_batch_norm->set_channels(channels); + coreml_batch_norm->set_epsilon(eps); + coreml_batch_norm->set_computemeanvar(false); + coreml_batch_norm->set_instancenormalization(false); + + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var + + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } @@ -119,6 +151,15 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu return false; } +#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64 + // To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) { + LOGS(logger, VERBOSE) << "float16 input is not supported on the iOS x86_64 simulator" + << " due to CoreML producing invalid output."; + return false; + } +#endif return true; } diff --git 
a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 8aa2dbae2531c..0482620b269a4 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -6,6 +6,7 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/shape_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -55,6 +56,64 @@ bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger } } // namespace +#if defined(COREML_ENABLE_MLPROGRAM) +static std::vector InferOutputShape(const std::vector& a, const std::vector& b) { + std::vector output_shape; + int64_t i_a = 0, j_b = 0; + if (a.size() >= b.size()) { + output_shape = a; + j_b -= a.size() - b.size(); + } else { + output_shape = b; + i_a -= b.size() - a.size(); + } + + for (size_t i = 0; i < output_shape.size(); i++, i_a++, j_b++) { + const int64_t a_dim = (i_a >= 0) ? a[i_a] : 1; + const int64_t b_dim = (j_b >= 0) ? b[j_b] : 1; + if (a_dim == -1 || b_dim == -1) { + output_shape[i] = -1; + } else { + output_shape[i] = std::max(a_dim, b_dim); + } + } + return output_shape; +} + +// Add variadic inputs to the model builder +// in onnx spec, some node allows variadic inputs, such as max(x, y, z, ...) +// while in coreml, maximum op only allows two inputs maximum(x, y) +// the conversion is doing the following: +// max(x, y, z, ...) -> max(max(x, y), z, ...) 
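// Illustrative aside, not part of the patch: the code below folds a variadic ONNX Max into a
// chain of binary CoreML "maximum" ops, and InferOutputShape computes each intermediate shape
// by right-aligned, NumPy-style broadcasting (-1 meaning a dynamic dimension). Standalone
// sketch of both ideas, using local helper names:
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> BroadcastShape(std::vector<int64_t> a, std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);  // keep the longer shape in 'a'
  const size_t offset = a.size() - b.size();
  for (size_t i = 0; i < b.size(); ++i) {
    int64_t& dim = a[offset + i];
    dim = (dim == -1 || b[i] == -1) ? -1 : std::max(dim, b[i]);  // assumes broadcastable dims
  }
  return a;
}

int main() {
  // max(x, y, z) is computed as max(max(x, y), z); shapes {2,3,4}, {3,1}, {1} broadcast to {2,3,4}
  auto shape = BroadcastShape({2, 3, 4}, {3, 1});
  shape = BroadcastShape(shape, {1});
  for (int64_t d : shape) std::cout << d << ' ';  // prints: 2 3 4
  std::cout << '\n';
}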
+static void AddVariadicInputs(std::unique_ptr* op, + ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) { + using namespace CoreML::Specification::MILSpec; + const auto& input_defs(node.InputDefs()); + std::string_view layer_input_name_x = model_builder.GetUniqueName(node, "variadic"); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + const int32_t elem_type = static_cast(input_dtype); + std::vector x0_shape, x1_shape; + GetShape(*input_defs[0], x0_shape, logger); + GetShape(*input_defs[1], x1_shape, logger); + x0_shape = InferOutputShape(x0_shape, x1_shape); + std::unique_ptr op_prev = std::move(*op); + for (size_t i = 2; i < input_defs.size(); i++) { + AddIntermediateOperationOutput(*op_prev, layer_input_name_x, elem_type, x0_shape); + std::unique_ptr op_cur = model_builder.CreateOperation(node, op_prev->type()); + AddOperationInput(*op_cur, "x", layer_input_name_x); + AddOperationInput(*op_cur, "y", input_defs[i]->Name()); + model_builder.AddOperation(std::move(op_prev)); + op_prev = std::move(op_cur); + layer_input_name_x = model_builder.GetUniqueName(node, "variadic"); + GetShape(*input_defs[i], x1_shape, logger); + x0_shape = InferOutputShape(x0_shape, x1_shape); + } + *op = std::move(op_prev); +} +#endif + Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& op_type(node.OpType()); @@ -70,6 +129,8 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const coreml_op_type = "add"; } else if (op_type == "Mul") { coreml_op_type = "mul"; + } else if (op_type == "Max") { + coreml_op_type = "maximum"; } else if (op_type == "Sub") { coreml_op_type = "sub"; } else if (op_type == "Div") { @@ -86,8 +147,11 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); AddOperationInput(*op, "x", input_defs[0]->Name()); AddOperationInput(*op, "y", input_defs[1]->Name()); + if (input_defs.size() > 2) { + // "max" node may have variadic inputs + AddVariadicInputs(&op, model_builder, node, logger); + } AddOperationOutput(*op, *node.OutputDefs()[0]); - model_builder.AddOperation(std::move(op)); } else #endif // defined (COREML_ENABLE_MLPROGRAM) @@ -157,6 +221,10 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn return false; } + if (node.OpType() == "Max" && !input_params.create_mlprogram) { + return false; + } + return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index fc8879abbefb0..7c7363d4c81ad 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -4,6 +4,7 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -18,14 +19,62 @@ class CastOpBuilder : public BaseOpBuilder { bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + public: + bool SupportsMLProgram() const override { return true; } }; -Status 
CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, - const Node& /* node */, - const logging::Logger& /* logger */) const { - // This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type. - // The ArgMax is fused with the Cast node and produces an int32 output. - // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here. +Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { +// This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type. +// The ArgMax is fused with the Cast node and produces an int32 output. +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.cast + + NodeAttrHelper helper(node); + auto cast_to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto::UNDEFINED); + std::string to_dtype = ""; + if (cast_to_type == ONNX_NAMESPACE::TensorProto::INT32 || cast_to_type == ONNX_NAMESPACE::TensorProto::INT64) { + to_dtype = "int32"; + // CoreML doesn't support int64, while ONNX uses int64 for indices and as well as data values. + // We convert the data inputs/outputs between int64 and int32 when calling onnxruntime::coreml::Model::Predict, + // and when adding int64 initializers to the CoreML model. + // CoreML operators can only produce int32 and not int64 values. + // Due to that there should be no actual int64 values inside the CoreML model and we can infer any + // ONNX_NAMESPACE::TensorProto::INT64 values to be int32. 
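// Illustrative aside, not part of the patch: the mapping performed in this function folds
// INT64 to "int32" (CoreML has no int64 tensor type; int64 data is converted at the model
// boundary), and a cast whose source and target types end up identical is emitted as an
// "identity" op instead of a "cast". Standalone sketch of that mapping, with hypothetical names:
#include <iostream>
#include <string>

enum class OnnxDtype { INT32, INT64, FLOAT, FLOAT16, BOOL };

std::string ToCoreMLDtype(OnnxDtype t) {
  switch (t) {
    case OnnxDtype::INT32:
    case OnnxDtype::INT64:   return "int32";  // int64 is narrowed when entering/leaving CoreML
    case OnnxDtype::FLOAT:   return "fp32";
    case OnnxDtype::FLOAT16: return "fp16";
    case OnnxDtype::BOOL:    return "bool";
  }
  return "unsupported";
}

int main() {
  std::cout << ToCoreMLDtype(OnnxDtype::INT64) << '\n';  // int32
  const bool same = ToCoreMLDtype(OnnxDtype::INT64) == ToCoreMLDtype(OnnxDtype::INT32);
  std::cout << (same ? "identity" : "cast") << '\n';     // identity: no conversion needed
}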
+ cast_to_type = ONNX_NAMESPACE::TensorProto::INT32; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT) { + to_dtype = "fp32"; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT16) { + to_dtype = "fp16"; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::BOOL) { + to_dtype = "bool"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported cast type: ", cast_to_type); + } + + std::string_view op_type = "cast"; + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (((input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT32) && + to_dtype == "int32") || + cast_to_type == input_dtype) { + op_type = "identity"; + } + + std::unique_ptr op = model_builder.CreateOperation(node, op_type); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + if (op_type == "cast") { + AddOperationInput(*op, "dtype", model_builder.AddScalarConstant(op->type(), "dtype", std::string(to_dtype))); + } + AddOperationOutput(*op, *node.OutputDefs()[0], cast_to_type); + model_builder.AddOperation(std::move(op)); + } +#endif + return Status::OK(); } @@ -36,6 +85,10 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } + if (input_params.create_mlprogram) { + return true; + } + const auto& prec_node = node.InputEdgesBegin()->GetNode(); /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax @@ -67,14 +120,39 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return true; } -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // We only check the type of input 0 const auto& input = *node.InputDefs()[0]; + const auto& output = *node.OutputDefs()[0]; - int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t input_type, output_type; + if (!GetType(input, input_type, logger)) { return false; + } + if (!GetType(output, output_type, logger)) { + return false; + } + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) && + (output_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + output_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + return true; + } else { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported."; + return false; + } + } +#endif // only support int64 coming from ArgMax (check for ArgMax is done in IsOpSupportedImpl()) if (input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index bc9e2f10296ed..f7046c213a8cb 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -98,26 +98,24 @@ Status 
ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const bool min_max_attribs = node.SinceVersion() < 11; std::string_view min_name; if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) - : node.InputDefs()[1]->Name(); + min_name = (min_max_attribs || !has_min) ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); } else { - min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", MLFloat16(min)) - : node.InputDefs()[1]->Name(); + min_name = (min_max_attribs || !has_min) ? model_builder.AddScalarConstant(clip_op.type(), "min", MLFloat16(min)) + : node.InputDefs()[1]->Name(); } AddOperationInput(clip_op, "alpha", min_name); - if (has_max) { - std::string_view max_name; - if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) - : node.InputDefs()[2]->Name(); - } else { - max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", MLFloat16(max)) - : node.InputDefs()[2]->Name(); - } - AddOperationInput(clip_op, "beta", max_name); + std::string_view max_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + max_name = (min_max_attribs || !has_max) ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + } else { + max_name = (min_max_attribs || !has_max) ? model_builder.AddScalarConstant(clip_op.type(), "max", MLFloat16(max)) + : node.InputDefs()[2]->Name(); } + AddOperationInput(clip_op, "beta", max_name); } } @@ -200,7 +198,9 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { float min, max; - return GetClipMinMax(input_params.graph_viewer, node, min, max, logger); + bool ret = GetClipMinMax(input_params.graph_viewer, node, min, max, logger); + // what does it mean if min == max? + return ret && (min != max); } void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index ddaa19c7fab18..fec14dfd093a0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -145,6 +145,20 @@ bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderI LOGS(logger, VERBOSE) << "DepthToSpace: CRD mode requires static shape"; return false; } + + if (mode == "DCR" && input_params.coreml_version < 7) { + int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + GetType(*input_defs[0], input_type, logger); + + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + // In CoreML version 6 (e.g., on an iOS 16 simulator) with DCR mode and float16 input, the output is all zeros + // in this unit test: TensorOpTest/1.DepthToSpaceTest_4. + // However, CoreML version 7 is fine. + // Don't support CoreML version < 7, DCR mode, and float16 input. 
+ LOGS(logger, VERBOSE) << "DepthToSpace: DCR mode with float16 input requires at least CoreML version 7."; + return false; + } + } } else { if (mode != "DCR") { LOGS(logger, VERBOSE) << "DepthToSpace: " << mode << " mode is not supported"; diff --git a/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc new file mode 100644 index 0000000000000..b4dc8d1647ad0 --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/optimizer/initializer.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" +#include + +namespace onnxruntime { +namespace coreml { + +class NormalizationOpBuilder : public BaseOpBuilder { + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; + Status AddGroupNormToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const; + + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + int GetMinSupportedOpSet(const Node& /* node */) const override { return 1; } + + public: + bool SupportsMLProgram() const override { return true; } +}; + +void NormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // skip everything except input0 for Normalization + const auto& input_defs = node.InputDefs(); + model_builder.AddInitializerToSkip(input_defs[1]->Name()); // scale + if (input_defs.size() > 2) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); // B + } +} + +Status NormalizationOpBuilder::AddToModelBuilderImpl( + [[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { + if (node.OpType() == "GroupNormalization") { + return AddGroupNormToModelBuilderImpl(model_builder, node, logger); + } +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + + const auto eps = helper.Get("epsilon", 1e-5f); + + std::vector input_shape; + // GetShape will never fail as we have already verified the input shape in IsOpSupportedImpl + GetShape(*input_defs[0], input_shape, logger); + + const auto rank = input_shape.size(); + auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + + std::vector axes(rank - axis); + std::iota(axes.begin(), axes.end(), axis); + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; 
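// --- Illustrative sketch (not part of this PR): the axes the normalization builder above passes
// --- to MIL layer_norm. ONNX LayerNormalization normalizes over [axis, rank), so after resolving
// --- a negative axis the builder fills the axes vector with consecutive values via std::iota.
// --- LayerNormAxes is a hypothetical standalone restatement of that step.
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

inline std::vector<int64_t> LayerNormAxes(int64_t axis, int64_t rank) {
  if (axis < 0) axis += rank;  // equivalent of HandleNegativeAxis
  std::vector<int64_t> axes(static_cast<std::size_t>(rank - axis));
  std::iota(axes.begin(), axes.end(), axis);
  return axes;  // e.g. rank 4, axis -1 -> {3}; rank 3, axis 1 -> {1, 2}
}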
+ std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + std::string_view op_name = (node.OpType() == "InstanceNormalization") ? "instance_norm" : "layer_norm"; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm + + std::unique_ptr op = model_builder.CreateOperation(node, op_name); + AddOperationInput(*op, "x", layer_input_name_x); + if (op_name == "layer_norm") { + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), input_defs[0]->Name() + "axes", axes)); + } + AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name() + "gamma", scale_tensor)); + if (input_defs.size() > 2) { + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name() + "beta", bias_tensor)); + } + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps)); + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } +#endif // (COREML_ENABLE_MLPROGRAM) + + return Status::OK(); +} + +Status NormalizationOpBuilder::AddGroupNormToModelBuilderImpl( + [[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + // Coreml hasn't supported GroupNorm yet. + // we decompose GroupNorm to sub ops and levrage LayerNorm to implement GroupNorm. 
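// --- Illustrative sketch (not part of this PR): the shapes used by the GroupNorm decomposition
// --- described in the comment above (reshape -> layer_norm -> reshape back -> mul(scale) ->
// --- add(bias)). MakeGroupNormShapes is a hypothetical restatement of the shape bookkeeping.
#include <cstdint>
#include <vector>

struct GroupNormShapes {
  std::vector<int64_t> grouped;     // shape fed to layer_norm
  std::vector<int64_t> scale_bias;  // broadcastable shape for the trailing mul/add
};

inline GroupNormShapes MakeGroupNormShapes(const std::vector<int64_t>& input_shape,
                                           int64_t num_groups) {
  GroupNormShapes s;
  s.grouped = input_shape;
  s.grouped.insert(s.grouped.begin() + 1, num_groups);  // [N, C, ...] -> [N, G, C, ...]
  s.grouped[2] = input_shape[1] / num_groups;           // -> [N, G, C/G, ...]
  s.scale_bias.assign(input_shape.size(), 1);
  s.scale_bias[1] = input_shape[1];                     // [1, C, 1, 1] for a 4-D input
  return s;
}
// e.g. input [2, 6, 4, 4] with num_groups 3 -> grouped [2, 3, 2, 4, 4], scale/bias [1, 6, 1, 1];
// layer_norm then normalizes over axes {2, 3, 4} of the grouped shape.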
+ // groupnorm --> reshape [b, num_groups, c // (num_groups), h, w] --> layer_norm --> reshape [b, c, h, w]->mul(scale)->add(bias) + + // scale and bias is required for group-norm by the onnx spec + const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + + const auto eps = helper.Get("epsilon", 1e-5f); + int64_t num_groups = helper.Get("num_groups", 1); // GroupNorm + + std::vector input_shape; + GetShape(*input_defs[0], input_shape, logger); + + const auto input_size = input_shape.size(); + int64_t axis = 2; + std::vector axes(input_size + 1 - axis); // Group add one more dim + std::iota(axes.begin(), axes.end(), axis); + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int64_t channel_dims = input_shape[1]; + + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + const int32_t elem_type = static_cast(input_dtype); + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm + // https://github.com/apple/coremltools/blob/9827d424b3c5b5fbb6ddc8891a000d87a188c84f/coremltools/converters/mil/frontend/torch/ops.py#L1354 + // reshape to [b, num_groups, c // (num_groups), h, w] + auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); + std::vector shape1 = input_shape; + shape1.insert(shape1.begin() + 1, num_groups); + shape1[2] = input_shape[1] / num_groups; + std::vector shape_scale_bias(input_shape.size(), 1); + shape_scale_bias[1] = channel_dims; + AddOperationInput(*reshape1, "x", node.InputDefs()[0]->Name()); + AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape1", shape1)); + layer_input_name_x = model_builder.GetUniqueName(node, "ln_reshape1_"); + AddIntermediateOperationOutput(*reshape1, layer_input_name_x, elem_type, shape1); + + std::unique_ptr layer_norm = model_builder.CreateOperation(node, "layer_norm"); + AddOperationInput(*layer_norm, "x", layer_input_name_x); + AddOperationInput(*layer_norm, "axes", model_builder.AddConstant(layer_norm->type(), "axes", axes)); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", eps)); + } + + const auto& ln_output_name = model_builder.GetUniqueName(node, "ln_output_"); + AddIntermediateOperationOutput(*layer_norm, ln_output_name, elem_type, shape1); + + auto reshape2 = model_builder.CreateOperation(node, "reshape", "post"); + AddOperationInput(*reshape2, "x", ln_output_name); + AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape2", input_shape)); + + const auto& reshape2_output_name = model_builder.GetUniqueName(node, "gn_reshape_output_"); + AddIntermediateOperationOutput(*reshape2, reshape2_output_name, elem_type, input_shape); + + auto mul = model_builder.CreateOperation(node, "mul", "post_mul"); + AddOperationInput(*mul, "x", reshape2_output_name); + AddOperationInput(*mul, "y", model_builder.AddConstant(mul->type(), "mul1", scale_tensor, shape_scale_bias)); + const auto& mul_output_name = 
model_builder.GetUniqueName(node, "mul_output_"); + AddIntermediateOperationOutput(*mul, mul_output_name, elem_type, input_shape); + + auto add = model_builder.CreateOperation(node, "add", "post_add"); + AddOperationInput(*add, "x", mul_output_name); + AddOperationInput(*add, "y", model_builder.AddConstant(add->type(), "add1", bias_tensor, shape_scale_bias)); + AddOperationOutput(*add, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(reshape1)); + model_builder.AddOperation(std::move(layer_norm)); + model_builder.AddOperation(std::move(reshape2)); + model_builder.AddOperation(std::move(mul)); + model_builder.AddOperation(std::move(add)); + } +#endif // (COREML_ENABLE_MLPROGRAM) + return Status::OK(); +} + +bool NormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // LayerNormalization may have three output in the training mode, but we only support the inference mode + // for InstanceNormalization and GroupNormalization, they only have one output, so this check will always return true + if (node.OutputDefs().size() != 1) { + LOGS(logger, VERBOSE) << "Your onnx model (with LayerNormalization) may be in training mode," + << " please export it for inferencing."; + return false; + } + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + + // groupnorm and layernorm has attribute "stash_type", while InstanceNormalization doesn't have this attribute + // Type of Mean and InvStdDev. This also specifies stage one’s computation precision. + // if stash_type is 1, this operator casts all input variables to 32-bit float, + // perform the computation, and finally cast Normalized back to the original type of X + // coreml didn't have a similiar attribute to stash_type, for now, we support the default value + if (node.OpType() != "InstanceNormalization") { + NodeAttrHelper helper(node); + const auto stash_type = helper.Get("stash_type", 1); + if (stash_type != 1) { + LOGS(logger, VERBOSE) << "stash_type != 1 is not supported"; + return false; + } + } + + const auto& scale_name = input_defs[1]->Name(); + const auto* scale_tensor = input_params.graph_viewer.GetConstantInitializer(scale_name); + if (!scale_tensor) { + LOGS(logger, VERBOSE) << "Scale must be a constant initializer"; + return false; + } + + if (input_defs.size() > 2) { + const auto& b_name = input_defs[2]->Name(); + const auto& b_tensor = input_params.graph_viewer.GetConstantInitializer(b_name); + if (!b_tensor) { + LOGS(logger, VERBOSE) << "Bias must be a constant initializer"; + return false; + } + } + + return true; +} + +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + if (!input_params.create_mlprogram) { + return false; + } + // We only check the type of input 0,1,2 + const auto& input_0 = *node.InputDefs()[0]; + const auto& input_1 = *node.InputDefs()[1]; + const auto& input_2 = node.InputDefs().size() > 2 ? 
*node.InputDefs()[2] : input_0; + int32_t input_type_0, input_type_1, input_type_2; + if (!GetType(input_0, input_type_0, logger)) { + return false; + } + if (!GetType(input_1, input_type_1, logger)) { + return false; + } + if (!GetType(input_2, input_type_2, logger)) { + return false; + } + if (input_type_0 != input_type_1 || input_type_0 != input_type_2) { + LOGS(logger, VERBOSE) << "Input types of LayerNorm must be the same"; + return false; + } + + if (input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + LOGS(logger, VERBOSE) << "Input types of LayerNorm must be float or float16"; + return false; + } + return true; +} + +void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index 5651b9cc5793e..d533b867bd454 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -5,10 +5,15 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" +#ifdef __APPLE__ +#include +#endif + namespace onnxruntime { namespace coreml { @@ -20,6 +25,7 @@ class ReductionOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; namespace { @@ -48,13 +54,12 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const logging::Logger& /* logger */) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - const auto& initializers(model_builder.GetInitializerTensors()); std::vector axes; NodeAttrHelper helper(node); if (input_defs.size() > 1 && input_defs[1]->Exists()) { - auto& axes_tensor = *initializers.at(input_defs[1]->Name()); + auto& axes_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); Initializer axes_initializer(axes_tensor); int64_t* data = axes_initializer.data(); int64_t size = axes_initializer.size(); @@ -66,28 +71,76 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const bool keepdims = helper.Get("keepdims", 1) != 0; const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::string_view coreml_op_type; + if (noop_with_empty_axes && axes.size() == 0) { + coreml_op_type = "identity"; + } else if (op_type == "ReduceSum") { + coreml_op_type = "reduce_sum"; + } else if (op_type == "ReduceMean") { + coreml_op_type = "reduce_mean"; + } else if (op_type == "ReduceMax") { + coreml_op_type = "reduce_max"; + } else if (op_type == "ReduceMin") { + coreml_op_type = 
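// --- Illustrative sketch (not part of this PR): the ONNX-to-MIL name mapping the reduction
// --- builder around this point implements. With empty axes and noop_with_empty_axes=1 the node
// --- degenerates to "identity" under ML Program; MilReduceOp is a hypothetical restatement.
#include <string>

inline std::string MilReduceOp(const std::string& onnx_op, bool noop_with_empty_axes,
                               bool axes_empty) {
  if (noop_with_empty_axes && axes_empty) return "identity";
  if (onnx_op == "ReduceSum") return "reduce_sum";
  if (onnx_op == "ReduceMean") return "reduce_mean";
  if (onnx_op == "ReduceMax") return "reduce_max";
  if (onnx_op == "ReduceMin") return "reduce_min";
  if (onnx_op == "ReduceProd") return "reduce_prod";
  return "";  // unexpected op type
}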
"reduce_min"; + } else if (op_type == "ReduceProd") { + coreml_op_type = "reduce_prod"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ReductionOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); + } + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + if (coreml_op_type != "identity") { + if (axes.size() > 0) { + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", axes)); + } + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", keepdims)); + } + AddOperationOutput(*op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + if (op_type == "ReduceSum") { + AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); + } else if (op_type == "ReduceMean") { + AddReductionParams(layer->mutable_reducemean(), axes, keepdims, noop_with_empty_axes); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ReductionOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } - std::unique_ptr layer = model_builder.CreateNNLayer(node); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - if (op_type == "ReduceSum") { - AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); - } else if (op_type == "ReduceMean") { - AddReductionParams(layer->mutable_reducemean(), axes, keepdims, noop_with_empty_axes); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "ReductionOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + model_builder.AddLayer(std::move(layer)); } - - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - - model_builder.AddLayer(std::move(layer)); return Status::OK(); } bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (!input_params.create_mlprogram && + (node.OpType() == "ReduceMax" || node.OpType() == "ReduceMin" || node.OpType() == "ReduceProd")) { + return false; + } + +#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64 + // skip ReductionOpTest.ReduceSum_half_bert because reduce_sum will output all zeros + int32_t input_type; + GetType(*input_defs[0], input_type, logger); + if (node.OpType() == "ReduceSum" && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + return false; + } +#endif NodeAttrHelper helper(node); @@ -99,18 +152,16 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInpu if (input_defs.size() > 1 && input_defs[1]->Exists()) { // 'axes' is optional input in new opsets const auto& axes_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, axes_name)) { + const auto* axes = input_params.graph_viewer.GetConstantInitializer(axes_name); + if (!axes) { LOGS(logger, VERBOSE) << "Axes of reduction must be a constant initializer"; return false; } - empty_axes = initializers.at(axes_name)->int64_data_size() == 0; + empty_axes = axes->int64_data_size() == 0; } - - if (empty_axes && noop_with_empty_axes) { - // TODO: When we 
add ML Program support we should enable this as it makes the node an Identity op - LOGS(logger, VERBOSE) << "CoreML doesn't support noop on empty axes for reduction layers" << std::endl; + if (empty_axes && noop_with_empty_axes && !input_params.create_mlprogram) { + LOGS(logger, VERBOSE) << "NeuralNetwork doesn't support noop on empty axes for reduction layers"; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc index a86e3d9538d87..243f949bdd48e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc @@ -2,7 +2,9 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/shape_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" // for NodeAttrHelper @@ -14,28 +16,132 @@ class ShapeOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /*logger*/) const { - auto layer = model_builder.CreateNNLayer(node); - layer->mutable_getshape(); - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + const auto& input_defs = node.InputDefs(); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + NodeAttrHelper node_attr_helper{node}; + int64_t size = -1; + int64_t num_dims = 0; + int64_t start = node_attr_helper.Get("start", 0); + // If the input shape is not available, size is -1 and start is 0 + if (input_defs[0]->Shape()) { + num_dims = input_defs[0]->Shape()->dim_size(); + start = HandleNegativeAxis(start, num_dims); + if (node_attr_helper.HasAttr("end")) { + int64_t end = HandleNegativeAxis(node_attr_helper.Get("end", -1), num_dims); + size = end - start; + } + } + + int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; + std::unique_ptr op = model_builder.CreateOperation(node, "shape"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + if (size != -1 || start != 0) { + std::string_view layer_input_name_x = model_builder.GetUniqueName(node, "slice_by_size"); + std::vector x0_shape{num_dims}; + AddIntermediateOperationOutput(*op, layer_input_name_x, output_datatype, x0_shape); + model_builder.AddOperation(std::move(op)); + + auto slice_op = model_builder.CreateOperation(node, "slice_by_size"); + AddOperationInput(*slice_op, "x", layer_input_name_x); + std::vector starts = {start}; + std::vector sizes = {size}; + AddOperationInput(*slice_op, "begin", model_builder.AddConstant(slice_op->type(), "begin", starts)); + AddOperationInput(*slice_op, "size", model_builder.AddConstant(slice_op->type(), "size", sizes)); + AddOperationOutput(*slice_op, *node.OutputDefs()[0], output_datatype); + 
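// --- Illustrative sketch (not part of this PR): how the Shape builder above converts the ONNX
// --- 'start'/'end' attributes into the begin/size operands of the trailing slice_by_size op.
// --- A size of -1 means "to the end of the shape"; when 'end' is set but the input rank is
// --- unknown, the support check rejects the node instead. ShapeSliceParams is hypothetical.
#include <cstdint>
#include <optional>
#include <utility>

inline std::pair<int64_t, int64_t> ShapeSliceParams(int64_t start, std::optional<int64_t> end,
                                                    std::optional<int64_t> rank) {
  int64_t size = -1;
  if (rank) {
    if (start < 0) start += *rank;
    if (end) {
      const int64_t e = (*end < 0) ? *end + *rank : *end;
      size = e - start;
    }
  }
  return {start, size};  // e.g. rank 4, start 1, end -1 -> begin 1, size 2 (dims 1 and 2)
}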
model_builder.AddOperation(std::move(slice_op)); + } else { + AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); + model_builder.AddOperation(std::move(op)); + } + } else // NOLINT +#endif + { + auto layer = model_builder.CreateNNLayer(node); + layer->mutable_getshape(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } -bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { + const auto* tensor_shape = node.InputDefs()[0]->Shape(); + NodeAttrHelper node_attr_helper{node}; - if (node_attr_helper.Get("start", 0) != 0) { - LOGS(logger, VERBOSE) << "Shape does not support 'start' attribute with value other than 0"; + if (!input_params.create_mlprogram) { + if (node_attr_helper.HasAttr("end")) { + LOGS(logger, VERBOSE) << "Shape does not support 'end' attribute"; + return false; + } + + if (node_attr_helper.Get("start", 0) != 0) { + LOGS(logger, VERBOSE) << "Shape does not support 'start' attribute with value other than 0"; + return false; + } + } else { + int64_t end = node_attr_helper.HasAttr("end") + ? node_attr_helper.Get("end", -1) + : std::numeric_limits::max(); + int64_t start = node_attr_helper.Get("start", 0); + // no need to slice if start is 0 and end is max + if (end == std::numeric_limits::max() && start == 0) { + } else if (tensor_shape == nullptr) { + LOGS(logger, VERBOSE) << "Shape does not support slicing when tensor_shape is not available"; + return false; + } + int64_t dim_size = tensor_shape->dim_size(); + int64_t size = node_attr_helper.HasAttr("end") + ? 
HandleNegativeAxis(node_attr_helper.Get("end", -1), dim_size) + : dim_size; + start = HandleNegativeAxis(start, dim_size); + size = size - start; + if (size == 0) { + LOGS(logger, VERBOSE) << "Shape does not support slicing when size is 0"; + return false; + } + } + + return true; +} + +bool ShapeOpBuilder::HasSupportedInputsImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // We only check the type of input 0 + const auto& input = *node.InputDefs()[0]; + + int32_t input_type; + if (!GetType(input, input_type, logger)) { return false; } - if (node_attr_helper.HasAttr("end")) { - LOGS(logger, VERBOSE) << "Shape does not support 'end' attribute"; + if (input_params.create_mlprogram) { + if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + return true; + } else { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported."; + return false; + } + } else if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported."; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index d6584124c6aba..c6e331feed326 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -4,6 +4,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,6 +19,7 @@ class SoftmaxOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -33,55 +35,100 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); int32_t axis_default_value = (node.SinceVersion() < 13) ? 1 : -1; const auto axis = helper.Get("axis", axis_default_value); - const auto axis_nonnegative = HandleNegativeAxis(axis, data_shape.size()); - - if (node.SinceVersion() >= 13 || (data_shape.size() == 2)) { - auto* coreml_softmaxnd = layer->mutable_softmaxnd(); - coreml_softmaxnd->set_axis(axis); - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(layer)); - } else { - // note: if opsets < 13, onnx Softmax coerces the input shape to be 2D based on axis. - // we need to manually reshape to 2D and apply SoftmaxND to axis -1 to achieve equivalent results for CoreML. 
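// --- Illustrative sketch (not part of this PR): the 2-D coercion described in the comment above
// --- and reused by the new ML Program path for opset < 13 Softmax. The input is flattened to
// --- [prod(dims before axis), prod(dims from axis)], softmax runs on the last axis, and the
// --- result is reshaped back to the original shape. SoftmaxCoerceTo2D is hypothetical.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

inline std::vector<int64_t> SoftmaxCoerceTo2D(const std::vector<int64_t>& shape, std::size_t axis) {
  const int64_t rows = std::accumulate(shape.begin(), shape.begin() + axis, int64_t{1},
                                       std::multiplies<int64_t>());
  const int64_t cols = std::accumulate(shape.begin() + axis, shape.end(), int64_t{1},
                                       std::multiplies<int64_t>());
  return {rows, cols};  // e.g. [2, 3, 4, 5] with axis 1 -> [2, 60]
}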
- TensorShape input_shape(data_shape); - const auto size_to_dimension = input_shape.SizeToDimension(axis_nonnegative); - const auto size_from_dimension = input_shape.SizeFromDimension(axis_nonnegative); - - TensorShapeVector target_shape; - target_shape.push_back(size_to_dimension); - target_shape.push_back(size_from_dimension); - - const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output"); - { // Add reshape layer - auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); - *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; - *reshape_layer->mutable_input()->Add() = input_name; - *reshape_layer->mutable_output()->Add() = reshape1_output_name; - model_builder.AddLayer(std::move(reshape_layer)); + auto axis_nonnegative = HandleNegativeAxis(axis, data_shape.size()); + +#if defined(COREML_ENABLE_MLPROGRAM) + // CoreML's softmax match onnx's softmax behavior since opset 13. + // For opset < 13, we need to reshape to 2D and set axis to -1 to simulate onnx softmax behavior. + // [B,D,...](onnx softmax opset 12, axis=1)->[B,D*...](CoreML softmax, axis=-1)->[B,D,...](reshape back) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + const int32_t elem_type = static_cast(input_dtype); + + std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + const bool need_reshape = node.SinceVersion() < 13 && axis_nonnegative != static_cast(data_shape.size()) - 1; + std::vector target_shape; + if (need_reshape) { + // reshape to 2D to simulate onnx softmax behavior + auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); + TensorShape input_shape(data_shape); + target_shape.push_back(input_shape.SizeToDimension(axis_nonnegative)); + target_shape.push_back(input_shape.SizeFromDimension(axis_nonnegative)); + axis_nonnegative = 1; + AddOperationInput(*reshape1, "x", layer_input_name_x); + AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape1", target_shape)); + layer_input_name_x = model_builder.GetUniqueName(node, "ln_reshape1_"); + AddIntermediateOperationOutput(*reshape1, layer_input_name_x, elem_type, target_shape); + model_builder.AddOperation(std::move(reshape1)); } - const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output"); - { + std::unique_ptr op = model_builder.CreateOperation(node, "softmax"); + AddOperationInput(*op, "x", layer_input_name_x); + AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis_nonnegative)); + if (!need_reshape) { + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else { + std::string_view ln_output_name = model_builder.GetUniqueName(node, "ln_reshape1_"); + AddIntermediateOperationOutput(*op, ln_output_name, elem_type, target_shape); + model_builder.AddOperation(std::move(op)); + auto reshape2 = model_builder.CreateOperation(node, "reshape", "post"); + AddOperationInput(*reshape2, "x", ln_output_name); + AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape2", data_shape)); + AddOperationOutput(*reshape2, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(reshape2)); + } + } else // NOLINT +#endif + { + if (node.SinceVersion() >= 13 || (data_shape.size() == 2)) { auto* coreml_softmaxnd = layer->mutable_softmaxnd(); - 
coreml_softmaxnd->set_axis(-1); - *layer->mutable_input()->Add() = reshape1_output_name; - *layer->mutable_output()->Add() = softmax_output_name; + coreml_softmaxnd->set_axis(axis); + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; model_builder.AddLayer(std::move(layer)); - } - { - // Add reshape back layer - auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); - *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; - *reshape_layer->mutable_input()->Add() = softmax_output_name; - *reshape_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(reshape_layer)); + } else { + // note: if opsets < 13, onnx Softmax coerces the input shape to be 2D based on axis. + // we need to manually reshape to 2D and apply SoftmaxND to axis -1 to achieve equivalent results for CoreML. + TensorShape input_shape(data_shape); + const auto size_to_dimension = input_shape.SizeToDimension(axis_nonnegative); + const auto size_from_dimension = input_shape.SizeFromDimension(axis_nonnegative); + + TensorShapeVector target_shape; + target_shape.push_back(size_to_dimension); + target_shape.push_back(size_from_dimension); + + const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output"); + { // Add reshape layer + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; + *reshape_layer->mutable_input()->Add() = input_name; + *reshape_layer->mutable_output()->Add() = reshape1_output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } + const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output"); + { + auto* coreml_softmaxnd = layer->mutable_softmaxnd(); + coreml_softmaxnd->set_axis(-1); + *layer->mutable_input()->Add() = reshape1_output_name; + *layer->mutable_output()->Add() = softmax_output_name; + model_builder.AddLayer(std::move(layer)); + } + { + // Add reshape back layer + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; + *reshape_layer->mutable_input()->Add() = softmax_output_name; + *reshape_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } } } return Status::OK(); } -bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, +bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index dbd0f48576f8b..6372f3136123b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -51,8 +51,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto calculate_remainder_and_chunk_size = [&](int32_t num_outputs) { // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; - uint64_t chunk_size = (split_dim_size + 
num_outputs - 1) / num_outputs; - uint64_t remainder = split_dim_size % chunk_size; + int64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs; + int64_t remainder = split_dim_size % chunk_size; return std::make_tuple(remainder, chunk_size); }; @@ -106,20 +106,20 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // if "split" is explicitly provided as an input // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); - auto split_span = unpacked_tensor.DataAsSpan(); + auto split_span = unpacked_tensor.DataAsSpan(); for (const auto& split_size : split_span) { coreml_splitnd->add_splitsizes(split_size); } } else if (node.SinceVersion() < 18) { - uint64_t num_outputs = narrow(node.OutputDefs().size()); + int64_t num_outputs = narrow(node.OutputDefs().size()); coreml_splitnd->set_numsplits(num_outputs); } else { // note: for opset 18+ 'num_outputs' is a required attribute - uint64_t num_outputs = narrow(helper.GetInt64("num_outputs").value()); + int64_t num_outputs = narrow(helper.GetInt64("num_outputs").value()); auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); if (remainder) { // uneven - auto split_sizes = InlinedVector(num_outputs, chunk_size); + auto split_sizes = InlinedVector(num_outputs, chunk_size); split_sizes.back() = remainder; for (size_t i = 0; i < split_sizes.size(); i++) { coreml_splitnd->add_splitsizes(split_sizes[i]); @@ -162,7 +162,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } const auto split_shape = *input_defs[1]->Shape(); - if (split_shape.dim_size() < 2) { + if (split_shape.dim(0).dim_value() < 2) { LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs."; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index e9cc1c2dbf638..a1b3a18265c70 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -5,10 +5,17 @@ #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" +#include "core/providers/cpu/tensor/unsqueeze.h" + +#ifdef __APPLE__ +#include +#endif namespace onnxruntime { namespace coreml { @@ -21,16 +28,16 @@ class SqueezeOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; namespace { -Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { +void GetAxes(ModelBuilder& model_builder, const Node& node, TensorShapeVector& axes) { // Squeeze opset 13 use input as axes if (node.SinceVersion() > 12) { // If axes is not provided, return an empty axes as default to squeeze all if (node.InputDefs().size() > 1) { - const auto& initializers(model_builder.GetInitializerTensors()); - const 
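// --- Illustrative sketch (not part of this PR): the uneven-split arithmetic used just above for
// --- opset 18+ Split with 'num_outputs'. chunk_size is the ceiling division and a non-zero
// --- remainder shortens the final chunk. SplitSizes is a hypothetical restatement.
#include <cstddef>
#include <cstdint>
#include <vector>

inline std::vector<int64_t> SplitSizes(int64_t split_dim_size, int64_t num_outputs) {
  const int64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs;  // ceil division
  const int64_t remainder = split_dim_size % chunk_size;
  std::vector<int64_t> sizes(static_cast<std::size_t>(num_outputs), chunk_size);
  if (remainder != 0) sizes.back() = remainder;
  return sizes;  // e.g. split_dim_size 10, num_outputs 3 -> chunk 4, remainder 2 -> {4, 4, 2}
}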
auto& axes_tensor = *initializers.at(node.InputDefs()[1]->Name()); + const auto& axes_tensor = *model_builder.GetConstantInitializer(node.InputDefs()[1]->Name()); Initializer unpacked_tensor(axes_tensor); auto raw_axes = unpacked_tensor.DataAsSpan(); const auto size = SafeInt(axes_tensor.dims()[0]); @@ -39,10 +46,9 @@ Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector()); + auto axes_attr = helper.Get("axes", std::vector()); + axes.assign(axes_attr.begin(), axes_attr.end()); } - - return Status::OK(); } } // namespace @@ -52,40 +58,103 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const } } +#if defined(COREML_ENABLE_MLPROGRAM) +void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder, + const Node& node, const logging::Logger& logger) { + const auto& input_defs(node.InputDefs()); + TensorShapeVector axes; + GetAxes(model_builder, node, axes); + + std::vector input_shape; + GetShape(*input_defs[0], input_shape, logger); + auto op = model_builder.CreateOperation(node, "reshape"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes); + AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape))); + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); +} +#endif + Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + [[maybe_unused]] const logging::Logger& logger) const { std::unique_ptr layer = model_builder.CreateNNLayer(node); - auto* coreml_squeeze = layer->mutable_squeeze(); - std::vector axes; - ORT_RETURN_IF_ERROR(GetAxes(model_builder, node, axes)); - if (axes.empty()) { - coreml_squeeze->set_squeezeall(true); - } else { - *coreml_squeeze->mutable_axes() = {axes.cbegin(), axes.cend()}; - coreml_squeeze->set_squeezeall(false); - } + TensorShapeVector axes; + GetAxes(model_builder, node, axes); +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs(node.InputDefs()); + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + +#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64 + // expand_dims has limited requirements for static shape, however, X86_64 has a bug that it can't handle scalar input + if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) { + HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger); + return Status::OK(); + } +#endif + std::string_view coreml_op_type = node.OpType() == "Squeeze" ? 
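// --- Illustrative sketch (not part of this PR): the Squeeze/Unsqueeze handling around this
// --- point. Squeeze lowers to MIL "squeeze" and Unsqueeze to "expand_dims", and the support
// --- check rejects an Unsqueeze whose output rank would exceed 5 (the limit assumed here).
// --- Both helpers below are hypothetical restatements, not ONNX Runtime API.
#include <cstddef>
#include <string>

inline std::string MilSqueezeOp(const std::string& onnx_op) {
  return onnx_op == "Squeeze" ? "squeeze" : "expand_dims";
}

inline bool UnsqueezeOutputRankSupported(std::size_t input_rank, std::size_t num_new_axes) {
  return input_rank + num_new_axes <= 5;  // e.g. a 4-D input may gain at most one new axis
}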
"squeeze" : "expand_dims"; + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + + if (!axes.empty()) { + // coreml supports negative axes + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes))); + } + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else // NOLINT +#endif + { + if (axes.empty()) { + coreml_squeeze->set_squeezeall(true); + } else { + *coreml_squeeze->mutable_axes() = {axes.cbegin(), axes.cend()}; + coreml_squeeze->set_squeezeall(false); + } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& /*logger*/) const { + const logging::Logger& logger) const { // Squeeze opset 13 uses input 1 as axes, if we have input 1 then it needs to be an initializer - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - const auto& axes_name = node.InputDefs()[1]->Name(); - if (!Contains(initializers, axes_name)) { - LOGS_DEFAULT(VERBOSE) << "Input axes of Squeeze must be known"; + const auto& input_defs = node.InputDefs(); + if (node.SinceVersion() > 12 && input_defs.size() > 1) { + const auto& axes_name = input_defs[1]->Name(); + if (!input_params.graph_viewer.GetConstantInitializer(axes_name)) { + LOGS(logger, VERBOSE) << "Input axes must be known"; return false; } } + if (node.OpType() == "Unsqueeze") { + if (!input_params.create_mlprogram) { + return false; + } + + int64_t num_of_new_dims = 0; + if (node.SinceVersion() > 12) { + num_of_new_dims = node.InputDefs()[1]->Shape()->dim(0).dim_value(); + } else { + NodeAttrHelper helper(node); + auto axes = helper.Get("axes", std::vector()); + num_of_new_dims = static_cast(axes.size()); + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger) || input_shape.size() + num_of_new_dims > 5) { + LOGS(logger, VERBOSE) << "Unsqueeze to output shape with > 5 dimensions is not supported"; + return false; + } + } return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index a6580920343c4..bc3cad004aec1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -16,6 +16,8 @@ class UnaryOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; bool SupportsMLProgram() const override { return true; } + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; }; Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -32,6 +34,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const coreml_op_type = "sqrt"; } else if (op_type == "Reciprocal") { coreml_op_type = "inverse"; + } else if (op_type 
== "Erf") { + coreml_op_type = "erf"; + } else if (op_type == "Round") { + coreml_op_type = "round"; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); @@ -74,6 +80,14 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } +bool UnaryOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& /*logger*/) const { + if (!input_params.create_mlprogram && (node.OpType() == "Erf" || node.OpType() == "Round")) { + return false; + } + return true; +} + void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 50faebf06875d..6486942199df7 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -8,12 +8,14 @@ #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/coreml_execution_provider.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/shape_utils.h" +#include "core/optimizer/initializer.h" #if defined(COREML_ENABLE_MLPROGRAM) // includes from coremltools-src in _deps @@ -400,14 +402,14 @@ std::string GetModelOutputPath(bool create_ml_program) { } // namespace ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names) : graph_viewer_(graph_viewer), logger_(logger), coreml_version_(coreml_version), - coreml_flags_(coreml_flags), - create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), + coreml_options_(coreml_options), + create_ml_program_(coreml_options.CreateMLProgram()), model_output_path_(GetModelOutputPath(create_ml_program_)), onnx_input_names_(std::move(onnx_input_names)), onnx_output_names_(std::move(onnx_output_names)), @@ -987,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { get_sanitized_io_info(std::move(input_output_info_)), std::move(scalar_outputs_), std::move(int64_outputs_), - logger_, coreml_flags_); + logger_, coreml_options_); } else #endif { @@ -997,19 +999,61 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { std::move(input_output_info_), std::move(scalar_outputs_), std::move(int64_outputs_), - logger_, coreml_flags_); + logger_, coreml_options_); } return model->LoadModel(); // load using CoreML API, including compilation } +#if defined(COREML_ENABLE_MLPROGRAM) +std::string_view ModelBuilder::AddConstant(std::string_view op_type, std::string_view value_type, + const ONNX_NAMESPACE::TensorProto& tensor, + std::optional> shape) { + const auto data_type = tensor.data_type(); + Initializer unpacked_tensor(tensor); + std::string_view ret; + switch (data_type) { + case 
ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + // case ONNX_NAMESPACE::TensorProto_DataType_INT32: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_INT8: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + default: + ORT_THROW("AddConstant: Unsupported data type: ", data_type); + } + + return ret; +} +#endif // static Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names, std::unique_ptr& model) { - ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags, + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_options, std::move(onnx_input_names), std::move(onnx_output_names)); ORT_RETURN_IF_ERROR(builder.CreateModel()); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index b3dfec29872a2..e19597cf0dc2e 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -7,6 +7,7 @@ #include "core/graph/graph_viewer.h" #include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/model/model.h" +#include "core/providers/coreml/coreml_options.h" #if defined(COREML_ENABLE_MLPROGRAM) // coremltools classes @@ -29,14 +30,14 @@ class IOpBuilder; class ModelBuilder { private: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names); public: // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names, std::unique_ptr& model); @@ -129,6 +130,12 @@ class ModelBuilder { return AddConstant(op_type, value_type, gsl::span(value), 
shape); } + // helper to convert a initializer to a constant + // by default, shape is inferred from the tensor.dims(), but can be provided to override if needed + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, + const ONNX_NAMESPACE::TensorProto& tensor, + std::optional> shape = std::nullopt); + /// /// Add a scalar value as a 'const' operation. See AddConstant for details. /// @@ -210,7 +217,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const int32_t coreml_version_; - const uint32_t coreml_flags_; + CoreMLOptions coreml_options_; const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index b0006b24e7d75..6e7df20a06097 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -21,15 +21,19 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateActivationOpBuilder("Relu", op_registrations); CreateActivationOpBuilder("PRelu", op_registrations); CreateActivationOpBuilder("LeakyRelu", op_registrations); + CreateActivationOpBuilder("Gelu", op_registrations); // Unary ops + CreateUnaryOpBuilder("Erf", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Round", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); // Binary elementwise ops CreateBinaryOpBuilder("Add", op_registrations); CreateBinaryOpBuilder("Div", op_registrations); CreateBinaryOpBuilder("Mul", op_registrations); + CreateBinaryOpBuilder("Max", op_registrations); CreateBinaryOpBuilder("Pow", op_registrations); CreateBinaryOpBuilder("Sub", op_registrations); @@ -41,10 +45,18 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { // Reduction ops CreateReductionOpBuilder("ReduceMean", op_registrations); + CreateReductionOpBuilder("ReduceMin", op_registrations); + CreateReductionOpBuilder("ReduceMax", op_registrations); + CreateReductionOpBuilder("ReduceProd", op_registrations); CreateReductionOpBuilder("ReduceSum", op_registrations); - CreateArgMaxOpBuilder("ArgMax", op_registrations); + // Normalization ops CreateBatchNormalizationOpBuilder("BatchNormalization", op_registrations); + CreateNormalizationOpBuilder("GroupNormalization", op_registrations); + CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); + CreateNormalizationOpBuilder("LayerNormalization", op_registrations); + + CreateArgMaxOpBuilder("ArgMax", op_registrations); CreateCastOpBuilder("Cast", op_registrations); CreateClipOpBuilder("Clip", op_registrations); CreateConcatOpBuilder("Concat", op_registrations); @@ -66,6 +78,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSoftmaxOpBuilder("Softmax", op_registrations); CreateSqueezeOpBuilder("Squeeze", op_registrations); CreateTransposeOpBuilder("Transpose", op_registrations); + CreateSqueezeOpBuilder("Unsqueeze", op_registrations); return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index 1990fb6400ce1..9b51b53d73e9e 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ 
b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -19,6 +19,7 @@ const std::unordered_map& GetOpBuilders(); void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateBatchNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index f2cd4d01174d3..5a2867e5524e4 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -23,27 +23,14 @@ namespace onnxruntime { constexpr const char* COREML = "CoreML"; -CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) +CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, - coreml_flags_(coreml_flags), + coreml_options_(options), coreml_version_(coreml::util::CoreMLVersion()) { LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_; if (coreml_version_ < MINIMUM_COREML_VERSION) { - LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; + ORT_THROW("CoreML EP is not supported on this platform."); } - -#if defined(COREML_ENABLE_MLPROGRAM) - if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION && - (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { - LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; - coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; - } -#else - if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { - LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork."; - coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; - } -#endif } CoreMLExecutionProvider::~CoreMLExecutionProvider() {} @@ -53,26 +40,17 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; - if (coreml_version_ < MINIMUM_COREML_VERSION) { - return result; - } - const auto& logger = *GetLogger(); // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes // TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the // implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate. - if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { - return result; - } - - const bool has_neural_engine = coreml::HasNeuralEngine(logger); - if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { - LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. 
CoreML EP will not be used."; + if (graph_viewer.IsSubgraph() && !coreml_options_.EnableOnSubgraph()) { return result; } - const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_); + const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, + coreml_options_.RequireStaticShape(), coreml_options_.CreateMLProgram()); const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger); const auto gen_metadef_name = @@ -135,7 +113,7 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector onnx_output_names = get_names(fused_node.OutputDefs()); const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_options_, std::move(onnx_input_names), std::move(onnx_output_names), coreml_model)); } @@ -210,7 +188,7 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector lock(model->GetMutex()); + std::unique_lock lock(model->GetMutex()); std::unordered_map outputs; outputs.reserve(model_outputs.size()); diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 24a001280eef5..650d81a4fecf7 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/inlined_containers.h" +#include "core/providers/coreml/coreml_options.h" #include "core/framework/execution_provider.h" #include "core/framework/model_metadef_id_generator.h" @@ -14,7 +14,7 @@ class Model; class CoreMLExecutionProvider : public IExecutionProvider { public: - CoreMLExecutionProvider(uint32_t coreml_flags); + CoreMLExecutionProvider(const CoreMLOptions& options); virtual ~CoreMLExecutionProvider(); std::vector> @@ -29,7 +29,7 @@ class CoreMLExecutionProvider : public IExecutionProvider { private: // The bit flags which define bool options for COREML EP, bits are defined as // COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h - uint32_t coreml_flags_; + CoreMLOptions coreml_options_; const int32_t coreml_version_; ModelMetadefIdGenerator metadef_id_generator_; diff --git a/onnxruntime/core/providers/coreml/coreml_options.cc b/onnxruntime/core/providers/coreml/coreml_options.cc new file mode 100644 index 0000000000000..4ec780208e528 --- /dev/null +++ b/onnxruntime/core/providers/coreml/coreml_options.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory.h" // defines flags +#include "core/providers/coreml/model/host_utils.h" +#include "core/providers/coreml/builders/helper.h" + +namespace onnxruntime { + +CoreMLOptions::CoreMLOptions(uint32_t coreml_flags) { + // validate the flags and populate the members. 
Code should eventually move from the EP constructor to here. + require_static_shape_ = (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0; + create_mlprogram_ = (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0; + enable_on_subgraph_ = (coreml_flags & COREML_FLAG_ENABLE_ON_SUBGRAPH) != 0; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (coreml::util::CoreMLVersion() < MINIMUM_COREML_MLPROGRAM_VERSION && create_mlprogram_ != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; + create_mlprogram_ = false; + } +#else + if (create_mlprogram_ != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork."; + create_mlprogram_ = false; + } +#endif + + compute_units_ = 0; // 0 for all + + if (coreml_flags & COREML_FLAG_USE_CPU_ONLY) { + compute_units_ |= COREML_FLAG_USE_CPU_ONLY; + } + if (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU) { + compute_units_ |= COREML_FLAG_USE_CPU_AND_GPU; + } + if (coreml_flags & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) { + compute_units_ |= COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE; + } + + // ensure that at most one device option is selected + if (compute_units_ & (compute_units_ - 1)) { + // multiple device options selected + ORT_THROW( + "Multiple device options selected, you should use at most one of the following options:" + "[COREML_FLAG_USE_CPU_ONLY, COREML_FLAG_USE_CPU_AND_GPU, COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE]"); + } + + const bool has_neural_engine = coreml::HasNeuralEngine(); + if (ComputeUnits(COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { + ORT_THROW("The current system does not have Apple Neural Engine."); + } +} + +void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& options) { + const std::unordered_map available_computeunits_options = { + {"CPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE}, + {"CPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU}, + {"CPUOnly", COREML_FLAG_USE_CPU_ONLY}, + {"ALL", COREML_FLAG_USE_NONE}, + }; + const std::unordered_map available_modelformat_options = { + {"MLProgram", COREML_FLAG_CREATE_MLPROGRAM}, + {"NeuralNetwork", COREML_FLAG_USE_NONE}, + }; + const std::unordered_set valid_options = { + kCoremlProviderOption_MLComputeUnits, + kCoremlProviderOption_ModelFormat, + kCoremlProviderOption_RequireStaticInputShapes, + kCoremlProviderOption_EnableOnSubgraphs, + kCoremlProviderOption_SpecializationStrategy, + kCoremlProviderOption_ProfileComputePlan, + kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU, + }; + // Validate the options + for (const auto& option : options) { + if (valid_options.find(option.first) == valid_options.end()) { + ORT_THROW("Unknown option: ", option.first); + } + if (kCoremlProviderOption_MLComputeUnits == option.first) { + if (available_computeunits_options.find(option.second) == available_computeunits_options.end()) { + ORT_THROW("Invalid value for option `", option.first, "`: ", option.second); + } else { + compute_units_ = available_computeunits_options.at(option.second); + } + } else if (kCoremlProviderOption_ModelFormat == option.first) { + if (available_modelformat_options.find(option.second) == available_modelformat_options.end()) { + ORT_THROW("Invalid value for option ", option.first, ": ", option.second); + } else { + create_mlprogram_ = available_modelformat_options.at(option.second) & COREML_FLAG_CREATE_MLPROGRAM; + } + } else if (kCoremlProviderOption_RequireStaticInputShapes == option.first) { + require_static_shape_ = option.second == "1"; + 
} else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) { + enable_on_subgraph_ = option.second == "1"; + } else if (kCoremlProviderOption_SpecializationStrategy == option.first) { + if (option.second != "Default" && option.second != "FastPrediction") { + ORT_THROW("Invalid value for option ", option.first, ": ", option.second, + ". Valid values are Default and FastPrediction."); + } + strategy_ = option.second; + } else if (kCoremlProviderOption_ProfileComputePlan == option.first) { + profile_compute_plan_ = option.second == "1"; + } else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) { + allow_low_precision_accumulation_on_gpu_ = option.second == "1"; + } + } +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_options.h b/onnxruntime/core/providers/coreml/coreml_options.h new file mode 100644 index 0000000000000..fd05c96927bd1 --- /dev/null +++ b/onnxruntime/core/providers/coreml/coreml_options.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +class CoreMLOptions { + private: + bool require_static_shape_{false}; + bool create_mlprogram_{false}; + bool enable_on_subgraph_{false}; + uint32_t compute_units_{0}; + std::string strategy_; + bool profile_compute_plan_{false}; + bool allow_low_precision_accumulation_on_gpu_{false}; + + public: + explicit CoreMLOptions(uint32_t coreml_flags); + + CoreMLOptions(const ProviderOptions& options) { + ValidateAndParseProviderOption(options); + } + bool RequireStaticShape() const { return require_static_shape_; } + bool CreateMLProgram() const { return create_mlprogram_; } + bool EnableOnSubgraph() const { return enable_on_subgraph_; } + uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; } + bool AllowLowPrecisionAccumulationOnGPU() const { return allow_low_precision_accumulation_on_gpu_; } + bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; } + bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; } + + private: + void ValidateAndParseProviderOption(const ProviderOptions& options); +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc index fcdf37c446ce7..bc8702d3290f6 100644 --- a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc +++ b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc @@ -9,21 +9,28 @@ using namespace onnxruntime; namespace onnxruntime { + struct CoreMLProviderFactory : IExecutionProviderFactory { - CoreMLProviderFactory(uint32_t coreml_flags) - : coreml_flags_(coreml_flags) {} + CoreMLProviderFactory(const CoreMLOptions& options) + : options_(options) {} ~CoreMLProviderFactory() override {} std::unique_ptr CreateProvider() override; - uint32_t coreml_flags_; + CoreMLOptions options_; }; std::unique_ptr CoreMLProviderFactory::CreateProvider() { - return std::make_unique(coreml_flags_); + return std::make_unique(options_); } std::shared_ptr CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) { - return std::make_shared(coreml_flags); + CoreMLOptions coreml_options(coreml_flags); + return std::make_shared(coreml_options); +} + +std::shared_ptr CoreMLProviderFactoryCreator::Create(const 
ProviderOptions& options) { + CoreMLOptions coreml_options(options); + return std::make_shared(coreml_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h index ba701724c4da9..93ec2af50698d 100644 --- a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h +++ b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h @@ -5,10 +5,12 @@ #include +#include "core/framework/provider_options.h" #include "core/providers/providers.h" namespace onnxruntime { struct CoreMLProviderFactoryCreator { static std::shared_ptr Create(uint32_t coreml_flags); + static std::shared_ptr Create(const ProviderOptions& options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index a9991ccb945ce..145c64e5320d3 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -26,6 +26,8 @@ // - iOS 16 ops // 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) // - iOS 17 ops +// 9 : iOS 18, macOS 15, tvOS 18, watchOS 11 (Core ML 8) +// - iOS 18 ops // // **NOTE** We use the Core ML version not the spec version. // @@ -39,6 +41,7 @@ #define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15)) #define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16)) #define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17)) +#define API_AVAILABLE_COREML8 API_AVAILABLE(macos(15), ios(18)) // @available is used in implementation code // Base required OS to run CoreML Specification Version 4 (Core ML 3) @@ -47,6 +50,7 @@ #define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *) #define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *) #define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *) +#define HAS_COREML8_OR_LATER @available(macOS 15, iOS 18, *) #endif diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 70052f50ae1c2..4239121a42c97 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -16,6 +16,8 @@ bool HasRequiredBaseOS() { } int32_t CoreMLVersion() { + if (HAS_COREML8_OR_LATER) + return 8; if (HAS_COREML7_OR_LATER) return 7; if (HAS_COREML6_OR_LATER) diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 75b9aaf2185c9..84b7d741b4714 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -11,13 +11,14 @@ #include #include "core/common/logging/logging.h" #include "core/common/status.h" -#include "core/platform/ort_mutex.h" +#include #if defined(__OBJC__) @class MLMultiArray; #endif namespace onnxruntime { +class CoreMLOptions; namespace coreml { class Execution; @@ -53,7 +54,7 @@ class Model { std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, - const logging::Logger& logger, uint32_t coreml_flags); + const logging::Logger& logger, const CoreMLOptions& coreml_options); ~Model(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); @@ -73,7 +74,7 @@ class Model { } // Mutex for exclusive lock to this model object - OrtMutex& GetMutex() { return mutex_; } + std::mutex& GetMutex() { return mutex_; } // Input and output names in the ORT fused node's order. 
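
The new ProviderOptions-based CoreMLProviderFactoryCreator::Create overload above is what the string-keyed options flow through when the EP is registered by name. Below is a sketch of the application-side usage via the C++ API; it assumes the public key spellings mirror the kCoremlProviderOption_* identifiers ("ModelFormat", "MLComputeUnits", and so on) and that SessionOptions::AppendExecutionProvider("CoreML", ...) routes to this factory, so treat the exact strings as illustrative and check coreml_provider_factory.h for the authoritative names.

    #include <onnxruntime_cxx_api.h>

    #include <string>
    #include <unordered_map>

    int main() {
      Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "coreml-options-demo"};
      Ort::SessionOptions session_options;

      // Values below are the ones accepted by ValidateAndParseProviderOption;
      // the key spellings are assumed to match the kCoremlProviderOption_* constants.
      std::unordered_map<std::string, std::string> coreml_options{
          {"ModelFormat", "MLProgram"},         // or "NeuralNetwork"
          {"MLComputeUnits", "CPUAndGPU"},      // "CPUOnly", "CPUAndNeuralEngine", "ALL"
          {"RequireStaticInputShapes", "0"},
          {"EnableOnSubgraphs", "0"},
          {"SpecializationStrategy", "FastPrediction"},  // or "Default"
          {"ProfileComputePlan", "0"},
      };
      session_options.AppendExecutionProvider("CoreML", coreml_options);

      // A real application would create the session next:
      // Ort::Session session{env, "model.onnx", session_options};
      return 0;
    }
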
// Names may have been adjusted from the originals due to CoreML naming rules. @@ -101,7 +102,7 @@ class Model { std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; - OrtMutex mutex_; + std::mutex mutex_; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 5f4eebc7d72ce..755dbfbd6e68c 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -25,6 +25,7 @@ #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/model/objc_str_utils.h" #include "core/providers/coreml/shape_utils.h" +#include "core/providers/coreml/coreml_options.h" // force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need // to manually do this @@ -300,6 +301,53 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, return Status::OK(); } +// since __clang_major__ >= 15, MLComputePlan is introduced in +// We are actually ensure the MacOS/IOS version and Xcode version is greater than `macOS 14.4, iOS 17.4`. +// The macro API_AVAILABLE should also be fine. +// Otherwise, the compiler will complain `MLComputePlan` is not defined. +// we define __clang_analyzer__ here is for bypass static analysis +void ProfileComputePlan(NSURL* compileUrl, MLModelConfiguration* config) { +#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__) + if (@available(macOS 14.4, iOS 17.4, *)) { + [MLComputePlan loadContentsOfURL:compileUrl + configuration:config + completionHandler:^(MLComputePlan* _Nullable computePlan, NSError* _Nullable error) { + if (!computePlan) { + NSLog(@"Error loading compute plan: %@", error); + // Handle error. + return; + } + MLModelStructureProgram* program = computePlan.modelStructure.program; + if (!program) { + NSLog(@"Error loading program from compute plan., this is not a mlprogram model"); + return; + } + + MLModelStructureProgramFunction* mainFunction = program.functions[@"main"]; + if (!mainFunction) { + NSLog(@"Error loading main function from program"); + return; + } + + NSArray* operations = mainFunction.block.operations; + NSLog(@"Number of operations, 'const' node is included. : %lu", operations.count); + for (MLModelStructureProgramOperation* operation in operations) { + // Get the compute device usage for the operation. + MLComputePlanDeviceUsage* computeDeviceUsage = [computePlan computeDeviceUsageForMLProgramOperation:operation]; + id preferredDevice = computeDeviceUsage.preferredComputeDevice; + // Get the estimated cost of executing the operation. + MLComputePlanCost* estimatedCost = [computePlan estimatedCostOfMLProgramOperation:operation]; + if (![operation.operatorName isEqualToString:@"const"]) { + NSLog(@"Operation: %@, Device Usage: %@, Estimated Cost: %f", operation.operatorName, preferredDevice, estimatedCost.weight); + } + } + }]; + } else { + NSLog(@"iOS 17.4+/macOS 14.4+ or later is required to use the compute plan API"); + } +#endif +} + // Internal Execution class // This class is part of the model class and handles the calls into CoreML. Specifically, it performs // 1. Compile the model by given path for execution @@ -307,7 +355,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, // 3. 
The compiled model will be removed in dealloc or removed using cleanup function class Execution { public: - Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); + Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options); ~Execution(); Status LoadModel(); @@ -320,13 +368,13 @@ Status Predict(const std::unordered_map& inputs, NSString* coreml_model_path_{nil}; NSString* compiled_model_path_{nil}; const logging::Logger& logger_; - uint32_t coreml_flags_{0}; + CoreMLOptions coreml_options_; MLModel* model_{nil}; }; -Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) +Execution::Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options) : logger_(logger), - coreml_flags_(coreml_flags) { + coreml_options_(coreml_options) { @autoreleasepool { coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); } @@ -395,9 +443,41 @@ Status Predict(const std::unordered_map& inputs, compiled_model_path_ = [compileUrl path]; MLModelConfiguration* config = [[MLModelConfiguration alloc] init]; - config.computeUnits = (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) - ? MLComputeUnitsCPUOnly - : MLComputeUnitsAll; + uint32_t coreml_compute_unit = coreml_options_.ComputeUnits(); + if (coreml_compute_unit & COREML_FLAG_USE_CPU_ONLY) { + config.computeUnits = MLComputeUnitsCPUOnly; + } else if (coreml_compute_unit & COREML_FLAG_USE_CPU_AND_GPU) { + config.computeUnits = MLComputeUnitsCPUAndGPU; + } else if (coreml_compute_unit & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) { + config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine + } else { + config.computeUnits = MLComputeUnitsAll; + } + + if (coreml_options_.AllowLowPrecisionAccumulationOnGPU()) { + config.allowLowPrecisionAccumulationOnGPU = YES; + } + +// Set the specialization strategy to FastPrediction for macOS 10.15+ +// since __clang_major__ >= 15, optimizationHints is introduced in +// Same as above comments for why we are checking __clang_major__. 
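
The compute-unit handling above is a pick-at-most-one-bit scheme: CoreMLOptions rejects flag combinations up front, and the model-loading code maps whichever single flag survived onto an MLComputeUnits value, defaulting to "all". A minimal standalone C++ sketch of that logic follows, with stand-in flag values rather than the real COREML_FLAG_* constants, mainly to show why the x & (x - 1) test detects multiple selections.

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>

    // Stand-in flag values for illustration only.
    constexpr uint32_t kCpuOnly = 0x1;
    constexpr uint32_t kCpuAndGpu = 0x2;
    constexpr uint32_t kCpuAndNeuralEngine = 0x4;

    enum class ComputeUnits { All, CpuOnly, CpuAndGpu, CpuAndNeuralEngine };

    ComputeUnits SelectComputeUnits(uint32_t flags) {
      const uint32_t device_bits = flags & (kCpuOnly | kCpuAndGpu | kCpuAndNeuralEngine);
      // x & (x - 1) clears the lowest set bit; a non-zero result means at least two bits were set.
      if (device_bits != 0 && (device_bits & (device_bits - 1)) != 0) {
        throw std::invalid_argument("at most one compute-unit flag may be set");
      }
      if (device_bits & kCpuOnly) return ComputeUnits::CpuOnly;
      if (device_bits & kCpuAndGpu) return ComputeUnits::CpuAndGpu;
      if (device_bits & kCpuAndNeuralEngine) return ComputeUnits::CpuAndNeuralEngine;
      return ComputeUnits::All;  // no device flag set: let Core ML choose
    }

    int main() {
      std::cout << static_cast<int>(SelectComputeUnits(kCpuAndGpu)) << "\n";  // 2
      std::cout << static_cast<int>(SelectComputeUnits(0)) << "\n";           // 0 (All)
      try {
        SelectComputeUnits(kCpuOnly | kCpuAndGpu);
      } catch (const std::invalid_argument&) {
        std::cout << "rejected\n";
      }
      return 0;
    }
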
+// we define __clang_analyzer__ here is for bypass static analysis +#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__) + if (HAS_COREML8_OR_LATER) { + MLOptimizationHints* optimizationHints = [[MLOptimizationHints alloc] init]; + if (coreml_options_.UseStrategy("FastPrediction")) { + optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction; + config.optimizationHints = optimizationHints; + } else if (coreml_options_.UseStrategy("Default")) { + optimizationHints.specializationStrategy = MLSpecializationStrategyDefault; + config.optimizationHints = optimizationHints; + } + } +#endif + if (coreml_options_.ProfileComputePlan()) { + ProfileComputePlan(compileUrl, config); + } + model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; if (error != nil || model_ == nil) { @@ -516,8 +596,8 @@ Status Predict(const std::unordered_map& inputs, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& logger, - uint32_t coreml_flags) - : execution_(std::make_unique(path, logger, coreml_flags)), + const CoreMLOptions& coreml_options) + : execution_(std::make_unique(path, logger, coreml_options)), model_input_names_(std::move(model_input_names)), model_output_names_(std::move(model_output_names)), input_output_info_(std::move(input_output_info)), diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc index c6f2e7401ea1e..e9036e2fc7e1a 100644 --- a/onnxruntime/core/providers/coreml/model/model_stub.cc +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -4,6 +4,7 @@ #include "core/providers/coreml/model/model.h" namespace onnxruntime { +class CoreMLOptions; namespace coreml { class Execution {}; @@ -15,7 +16,7 @@ Model::Model(const std::string& /*path*/, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& /*logger*/, - uint32_t /*coreml_flags*/) + const CoreMLOptions& /*coreml_flags*/) : execution_(std::make_unique()), model_input_names_(std::move(model_input_names)), model_output_names_(std::move(model_output_names)), diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 424bee63511ad..0499a15e1df0a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2,7 +2,8 @@ // Licensed under the MIT License. #include "core/providers/cpu/cpu_execution_provider.h" -#include + +#include "core/framework/allocator_utils.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/int4.h" @@ -30,14 +31,7 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} {} std::vector CPUExecutionProvider::CreatePreferredAllocators() { - bool create_arena = info_.create_arena; -#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) || defined(ABSL_HAVE_ADDRESS_SANITIZER) - // JEMalloc/mimalloc already have memory pool, so just use device allocator. 
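
For context on the CPU allocator hunk in progress here: the inline #if ladder that decided whether an arena allocator is usable is being removed in favour of a shared helper, DoesCpuAllocatorSupportArenaUsage(), presumably declared in the newly included core/framework/allocator_utils.h; the refactored call site appears a few lines below. The sketch that follows restates the removed branches as one predicate so the intent of the refactor is easy to see; the real helper may differ in detail.

    #include <iostream>

    // Equivalent of the removed per-call-site logic, expressed as one predicate.
    constexpr bool CpuAllocatorSupportsArena() {
    #if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) || defined(ABSL_HAVE_ADDRESS_SANITIZER)
      return false;  // jemalloc/mimalloc already pool memory; ASan prefers the plain allocator
    #elif !(defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64))
      return false;  // 32-bit x86: arena growth can hit integer overflow and loop forever
    #else
      return true;
    #endif
    }

    int main() {
      // Call-site shape after the refactor (create_arena would come from provider info):
      // const bool create_arena = CpuAllocatorSupportsArena() ? info_.create_arena : false;
      std::cout << (CpuAllocatorSupportsArena() ? "arena" : "device allocator") << "\n";
      return 0;
    }
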
- create_arena = false; -#elif !(defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) - // Disable Arena allocator for x86_32 build because it may run into infinite loop when integer overflow happens - create_arena = false; -#endif + const bool create_arena = DoesCpuAllocatorSupportArenaUsage() ? info_.create_arena : false; AllocatorCreationInfo device_info{[](int) { return std::make_unique(); }, DEFAULT_CPU_ALLOCATOR_DEVICE_ID, create_arena}; @@ -379,8 +373,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, uint8_t, + QLinearMatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, int8_t, + QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger); @@ -1108,6 +1104,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear); @@ -1691,10 +1689,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { uint8_t, QuantizeLinear)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, @@ -2923,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder); @@ -3041,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { TreeEnsembleRegressor)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index dfa27f1f44d5a..091b01b81b5b1 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -138,7 +138,7 @@ static TensorProto::DataType InferDataType(const Tensor& tensor); Status RandomNormal::Compute(OpKernelContext* ctx) const { Tensor& Y = *ctx->Output(0, shape_); - std::lock_guard l(generator_mutex_); + std::lock_guard l(generator_mutex_); auto status = RandomNormalCompute(mean_, scale_, generator_, dtype_, Y); return status; @@ -147,7 +147,7 @@ Status RandomNormal::Compute(OpKernelContext* ctx) const { Status RandomUniform::Compute(OpKernelContext* ctx) const { Tensor& Y = *ctx->Output(0, shape_); - std::lock_guard l(generator_mutex_); + std::lock_guard l(generator_mutex_); auto status = RandomUniformCompute(low_, high_, generator_, dtype_, Y); return status; @@ -169,7 +169,7 @@ Status RandomNormalLike::Compute(OpKernelContext* ctx) const { "Could not infer data type from input tensor with data type ", X.DataType()); - std::lock_guard l(generator_mutex_); + std::lock_guard l(generator_mutex_); status = RandomNormalCompute(mean_, scale_, generator_, dtype, *Y); return status; @@ -190,7 +190,7 @@ Status RandomUniformLike::Compute(OpKernelContext* ctx) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not infer data type from input tensor with data type ", X.DataType()); - std::lock_guard l(generator_mutex_); + std::lock_guard l(generator_mutex_); status = RandomUniformCompute(low_, high_, generator_, dtype, *Y); return status; @@ -310,7 +310,7 @@ Status Multinomial::Compute(OpKernelContext* ctx) const { Tensor* Y = ctx->Output(0, {batch_size, num_samples_}); Status status = Status::OK(); - std::lock_guard l(generator_mutex_); + std::lock_guard l(generator_mutex_); switch (output_dtype_) { case TensorProto::INT32: { status = MultinomialCompute(ctx, X, batch_size, num_classes, num_samples_, generator_, *Y); diff --git a/onnxruntime/core/providers/cpu/generator/random.h b/onnxruntime/core/providers/cpu/generator/random.h index 8a0390fe7af8c..1cfb276052f85 100644 --- a/onnxruntime/core/providers/cpu/generator/random.h +++ b/onnxruntime/core/providers/cpu/generator/random.h @@ -9,7 +9,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/framework/random_seed.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -58,7 +58,7 @@ class RandomNormal final : public OpKernel { // use generator_mutex_ to ensure Compute() can be called concurrently. // this is to ensure that a model with random generators is deterministic and still can be executed in parallel. mutable std::default_random_engine generator_; - mutable onnxruntime::OrtMutex generator_mutex_; + mutable std::mutex generator_mutex_; ONNX_NAMESPACE::TensorProto::DataType dtype_; TensorShape shape_; }; @@ -94,7 +94,7 @@ class RandomNormalLike final : public OpKernel { // see comments for generator_ and generator_mutex_ in RandomNormal class. 
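
The Random* kernels above keep a shared std::default_random_engine per kernel instance and serialize access with the new std::mutex so that a model containing random generators stays deterministic even when Compute() is called from several threads. A minimal sketch of that pattern follows (a stand-alone class, not the ORT kernel), which also spells out the std::lock_guard<std::mutex> type in full.

    #include <cstddef>
    #include <mutex>
    #include <random>
    #include <vector>

    // Stand-alone illustration of the RandomNormal-style locking pattern.
    class SeededNormalSource {
     public:
      explicit SeededNormalSource(unsigned seed) : generator_(seed) {}

      // const, like OpKernel::Compute(); the members are mutable for that reason.
      std::vector<float> Sample(std::size_t n, float mean, float scale) const {
        std::lock_guard<std::mutex> lock(generator_mutex_);  // one sampler at a time
        std::normal_distribution<float> dist(mean, scale);
        std::vector<float> out(n);
        for (auto& v : out) {
          v = dist(generator_);
        }
        return out;
      }

     private:
      mutable std::default_random_engine generator_;
      mutable std::mutex generator_mutex_;
    };

    int main() {
      SeededNormalSource source(42);
      auto values = source.Sample(4, 0.0f, 1.0f);
      return values.size() == 4 ? 0 : 1;
    }
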
mutable std::default_random_engine generator_; - mutable onnxruntime::OrtMutex generator_mutex_; + mutable std::mutex generator_mutex_; ONNX_NAMESPACE::TensorProto::DataType dtype_ = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; // optional and may be inferred }; @@ -132,7 +132,7 @@ class RandomUniform final : public OpKernel { // see comments for generator_ and generator_mutex_ in RandomNormal class. mutable std::default_random_engine generator_; - mutable onnxruntime::OrtMutex generator_mutex_; + mutable std::mutex generator_mutex_; ONNX_NAMESPACE::TensorProto::DataType dtype_; TensorShape shape_; }; @@ -167,7 +167,7 @@ class RandomUniformLike final : public OpKernel { // see comments for generator_ and generator_mutex_ in RandomNormal class. mutable std::default_random_engine generator_; - mutable onnxruntime::OrtMutex generator_mutex_; + mutable std::mutex generator_mutex_; ONNX_NAMESPACE::TensorProto::DataType dtype_ = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; // optional and may be inferred }; @@ -201,7 +201,7 @@ class Multinomial final : public OpKernel { // see comments for generator_ and generator_mutex_ in RandomNormal class. mutable std::default_random_engine generator_; - mutable onnxruntime::OrtMutex generator_mutex_; + mutable std::mutex generator_mutex_; ONNX_NAMESPACE::TensorProto::DataType output_dtype_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 2f4ebeabe043e..3359b2a69fe83 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -20,44 +20,48 @@ enum class OUTPUT_MODE { ALL_SCORES }; -enum NODE_MODE : uint8_t { - LEAF = 1, - BRANCH_LEQ = 2, - BRANCH_LT = 4, - BRANCH_GTE = 6, - BRANCH_GT = 8, - BRANCH_EQ = 10, - BRANCH_NEQ = 12 +enum NODE_MODE_ONNX : uint8_t { + BRANCH_LEQ = 0, + BRANCH_LT = 1, + BRANCH_GTE = 2, + BRANCH_GT = 3, + BRANCH_EQ = 4, + BRANCH_NEQ = 5, + BRANCH_MEMBER = 6, + LEAF = 7, }; -static inline NODE_MODE MakeTreeNodeMode(const std::string& input) { +static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) { if (input == "BRANCH_LEQ") { - return NODE_MODE::BRANCH_LEQ; + return NODE_MODE_ONNX::BRANCH_LEQ; } if (input == "LEAF") { - return NODE_MODE::LEAF; + return NODE_MODE_ONNX::LEAF; } if (input == "BRANCH_LT") { - return NODE_MODE::BRANCH_LT; + return NODE_MODE_ONNX::BRANCH_LT; } if (input == "BRANCH_GTE") { - return NODE_MODE::BRANCH_GTE; + return NODE_MODE_ONNX::BRANCH_GTE; } if (input == "BRANCH_GT") { - return NODE_MODE::BRANCH_GT; + return NODE_MODE_ONNX::BRANCH_GT; } if (input == "BRANCH_EQ") { - return NODE_MODE::BRANCH_EQ; + return NODE_MODE_ONNX::BRANCH_EQ; } - return NODE_MODE::BRANCH_NEQ; + if (input == "BRANCH_MEMBER") { + return NODE_MODE_ONNX::BRANCH_MEMBER; + } + return NODE_MODE_ONNX::BRANCH_NEQ; } -enum class POST_EVAL_TRANSFORM { - NONE, - LOGISTIC, - SOFTMAX, - SOFTMAX_ZERO, - PROBIT +enum class POST_EVAL_TRANSFORM : int64_t { + NONE = 0, + LOGISTIC = 1, + SOFTMAX = 2, + SOFTMAX_ZERO = 3, + PROBIT = 4 }; static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) { @@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) { return POST_EVAL_TRANSFORM::PROBIT; } -enum class AGGREGATE_FUNCTION { - AVERAGE, - SUM, - MIN, - MAX +enum class AGGREGATE_FUNCTION : int64_t { + AVERAGE = 0, + SUM = 1, + MIN = 2, + MAX = 3 }; static inline AGGREGATE_FUNCTION 
MakeAggregateFunction(const std::string& input) { diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc new file mode 100644 index 0000000000000..3ff501d96b72d --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/ml/tree_ensemble.h" +#include "core/providers/cpu/ml/tree_ensemble_helper.h" +#include "core/common/inlined_containers_fwd.h" + +namespace onnxruntime { +namespace ml { + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + TreeEnsemble, + 5, + float, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0), + TreeEnsemble); + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + TreeEnsemble, + 5, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0), + TreeEnsemble); + +template +TreeEnsemble::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) { + if constexpr (std::is_same::value) { + p_tree_ensemble_ = std::make_unique>(); + } else { + p_tree_ensemble_ = std::make_unique>(); + } + ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info)); +} + +template +Status TreeEnsemble::GetRemovableAttributes(InlinedVector& removable_attributes) const { + InlinedVector names{ + "leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs", + "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true", + "nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"}; + removable_attributes.swap(names); + return Status::OK(); +} + +template +common::Status TreeEnsemble::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (X->Shape().NumDimensions() == 0) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input shape needs to be at least a single dimension."); + } + int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0]; + Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()}); + return p_tree_ensemble_->compute(context, X, Y, NULL); +} + +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h new file mode 100644 index 0000000000000..697aae045a7e3 --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "tree_ensemble_common.h" + +namespace onnxruntime { +namespace ml { +template +class TreeEnsemble final : public OpKernel { + typedef T InputType; // input type + typedef float OutputType; // output type + public: + explicit TreeEnsemble(const OpKernelInfo& info); + common::Status Compute(OpKernelContext* context) const override; + Status GetRemovableAttributes(InlinedVector& removable_attributes) const override; + + private: + // Pointer on one instance of + // detail::TreeEnsembleCommonV5 + // where ThresholdType is defined after accessing the attributes. 
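
The TreeEnsemble<T> constructor above picks the threshold precision from the input type: double inputs keep double thresholds, everything else computes in float. The sketch below shows that if-constexpr dispatch with a stubbed implementation class; the real detail::TreeEnsembleCommonV5 lives in tree_ensemble_common.h and appears to take the same two template parameters.

    #include <memory>
    #include <type_traits>

    // Stub standing in for detail::TreeEnsembleCommonV5<InputType, ThresholdType>.
    struct ITreeEnsembleImpl {
      virtual ~ITreeEnsembleImpl() = default;
    };

    template <typename InputType, typename ThresholdType>
    struct TreeEnsembleCommonV5Stub : ITreeEnsembleImpl {};

    template <typename T>
    std::unique_ptr<ITreeEnsembleImpl> MakeTreeEnsembleImpl() {
      if constexpr (std::is_same_v<T, double>) {
        return std::make_unique<TreeEnsembleCommonV5Stub<T, double>>();  // double thresholds
      } else {
        return std::make_unique<TreeEnsembleCommonV5Stub<T, float>>();   // float thresholds
      }
    }

    int main() {
      auto float_kernel = MakeTreeEnsembleImpl<float>();
      auto double_kernel = MakeTreeEnsembleImpl<double>();
      return (float_kernel && double_kernel) ? 0 : 1;
    }
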
+ std::unique_ptr p_tree_ensemble_; +}; +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index b031a6f0cefa3..bf3fd37d10f5c 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -78,6 +78,40 @@ union PtrOrWeight { } weight_data; }; +enum NODE_MODE_ORT : uint8_t { + LEAF = 1, + BRANCH_LEQ = 2, + BRANCH_LT = 4, + BRANCH_GTE = 6, + BRANCH_GT = 8, + BRANCH_EQ = 10, + BRANCH_NEQ = 12, + BRANCH_MEMBER = 14, +}; + +inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) { + switch (node_mode) { + case NODE_MODE_ONNX::LEAF: + return NODE_MODE_ORT::LEAF; + case NODE_MODE_ONNX::BRANCH_LEQ: + return NODE_MODE_ORT::BRANCH_LEQ; + case NODE_MODE_ONNX::BRANCH_LT: + return NODE_MODE_ORT::BRANCH_LT; + case NODE_MODE_ONNX::BRANCH_GTE: + return NODE_MODE_ORT::BRANCH_GTE; + case NODE_MODE_ONNX::BRANCH_GT: + return NODE_MODE_ORT::BRANCH_GT; + case NODE_MODE_ONNX::BRANCH_EQ: + return NODE_MODE_ORT::BRANCH_EQ; + case NODE_MODE_ONNX::BRANCH_NEQ: + return NODE_MODE_ORT::BRANCH_NEQ; + case NODE_MODE_ONNX::BRANCH_MEMBER: + return NODE_MODE_ORT::BRANCH_MEMBER; + default: + ORT_THROW("Unexpected value for node_mode"); + }; +} + template struct TreeNodeElement { int feature_id; @@ -98,10 +132,10 @@ struct TreeNodeElement { // weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also // stored in `value_or_unique_weight`. PtrOrWeight truenode_or_weight; - uint8_t flags; + NODE_MODE_ORT flags; - inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); } - inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); } + inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); } + inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); } inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; } }; diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h new file mode 100644 index 0000000000000..d2d1ba9863ac7 --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -0,0 +1,321 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
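
NODE_MODE_ORT above uses even, nibble-sized values on purpose: the low four bits of TreeNodeElement::flags hold the mode, LEAF is the only value with bit 0 set, and the higher bits carry per-node markers such as the missing-value-track bit. The sketch below replays that bit layout with stand-in names; the value used here for the missing-track bit is an assumption for illustration, see MissingTrack in tree_ensemble_aggregator.h for the real one.

    #include <cstdint>
    #include <iostream>

    // Mirrors the NODE_MODE_ORT values above; kMissingTrackTrue is a stand-in
    // (assumed to be the first bit above the mode nibble).
    enum NodeModeOrt : uint8_t {
      kLeaf = 1,
      kBranchLeq = 2,
      kBranchLt = 4,
      kBranchGte = 6,
      kBranchGt = 8,
      kBranchEq = 10,
      kBranchNeq = 12,
      kBranchMember = 14,
    };
    constexpr uint8_t kMissingTrackTrue = 0x10;

    constexpr NodeModeOrt Mode(uint8_t flags) { return static_cast<NodeModeOrt>(flags & 0xF); }
    constexpr bool IsNotLeaf(uint8_t flags) { return (flags & kLeaf) == 0; }
    constexpr bool MissingTrackTrue(uint8_t flags) { return (flags & kMissingTrackTrue) != 0; }

    int main() {
      const uint8_t flags = static_cast<uint8_t>(kBranchMember | kMissingTrackTrue);
      std::cout << static_cast<int>(Mode(flags)) << " "  // 14 (BRANCH_MEMBER)
                << IsNotLeaf(flags) << " "               // 1
                << MissingTrackTrue(flags) << "\n";      // 1
      return 0;
    }
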
+ +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "ml_common.h" +#include "tree_ensemble_helper.h" +#include + +namespace onnxruntime { +namespace ml { +namespace detail { + +inline bool _isnan_(float x) { return std::isnan(x); } +inline bool _isnan_(double x) { return std::isnan(x); } +inline bool _isnan_(int64_t) { return false; } +inline bool _isnan_(int32_t) { return false; } + +template +struct TreeEnsembleAttributesV3 { + TreeEnsembleAttributesV3() {} + TreeEnsembleAttributesV3(const OpKernelInfo& info, bool classifier) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); + if (classifier) { + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", target_class_weights_as_tensor)); + } else { + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_class_weights_as_tensor)); + } +#endif + + aggregate_function = info.GetAttrOrDefault("aggregate_function", "SUM"); + base_values = info.GetAttrsOrDefault("base_values"); + nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids"); + nodes_featureids = info.GetAttrsOrDefault("nodes_featureids"); + nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true"); + + std::vector nodes_modes_string = info.GetAttrsOrDefault("nodes_modes"); + nodes_modes.reserve(nodes_modes_string.size()); + for (auto s : nodes_modes_string) { + nodes_modes.emplace_back(MakeTreeNodeMode(s)); + } + + nodes_nodeids = info.GetAttrsOrDefault("nodes_nodeids"); + nodes_treeids = info.GetAttrsOrDefault("nodes_treeids"); + nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids"); + nodes_values = info.GetAttrsOrDefault("nodes_values"); + post_transform = info.GetAttrOrDefault("post_transform", "NONE"); + + if (classifier) { + target_class_ids = info.GetAttrsOrDefault("class_ids"); + target_class_nodeids = info.GetAttrsOrDefault("class_nodeids"); + target_class_treeids = info.GetAttrsOrDefault("class_treeids"); + target_class_weights = info.GetAttrsOrDefault("class_weights"); + classlabels_strings = info.GetAttrsOrDefault("classlabels_strings"); + classlabels_int64s = info.GetAttrsOrDefault("classlabels_int64s"); + n_targets_or_classes = classlabels_strings.empty() ? 
classlabels_int64s.size() + : classlabels_strings.size(); + } else { + n_targets_or_classes = info.GetAttrOrDefault("n_targets", 0); + target_class_ids = info.GetAttrsOrDefault("target_ids"); + target_class_nodeids = info.GetAttrsOrDefault("target_nodeids"); + target_class_treeids = info.GetAttrsOrDefault("target_treeids"); + target_class_weights = info.GetAttrsOrDefault("target_weights"); + + ORT_ENFORCE(n_targets_or_classes > 0); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes_string.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() || + nodes_falsenodeids.size() == nodes_values_as_tensor.size()); + ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size()); + ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size()); + ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size()); + ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty()); + ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty()); + ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty()); + ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty()); + ORT_ENFORCE(nodes_modes_string.size() < std::numeric_limits::max()); + } + } + + std::string aggregate_function; + std::vector base_values; + std::vector base_values_as_tensor; + int64_t n_targets_or_classes; + std::vector nodes_falsenodeids; + std::vector nodes_featureids; + std::vector nodes_hitrates; + std::vector nodes_hitrates_as_tensor; + std::vector nodes_missing_value_tracks_true; + std::vector nodes_modes; + std::vector nodes_nodeids; + std::vector nodes_treeids; + std::vector nodes_truenodeids; + std::vector nodes_values; + std::vector nodes_values_as_tensor; + std::string post_transform; + std::vector target_class_ids; + std::vector target_class_nodeids; + std::vector target_class_treeids; + std::vector target_class_weights; + std::vector target_class_weights_as_tensor; + std::vector classlabels_strings; + std::vector classlabels_int64s; + std::vector class_labels; +}; + +template +struct TreeEnsembleAttributesV5 { + TreeEnsembleAttributesV5() {} + TreeEnsembleAttributesV5(const OpKernelInfo& info) { +#if !defined(ORT_MINIMAL_BUILD) + std::vector nodes_modes_i; + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "leaf_weights", leaf_weights)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "membership_values", membership_values)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates", nodes_hitrates)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes_i)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_splits", nodes_splits)); + nodes_modes.reserve(nodes_modes.size()); + for (auto i : nodes_modes_i) { + nodes_modes.push_back(static_cast(i)); + } +#else + // GetVectorAttrsOrDefault is not part of the minimal build. + // As a result, TreeEnsemble v5 cannot be available in this build. 
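
The V5 attributes pack all membership sets into a single membership_values list, using NaN as the separator between nodes; getMembershipValuesById further down regroups them per BRANCH_MEMBER node. A standalone sketch of that regrouping follows, with node modes reduced to a bool for brevity.

    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Regroups the NaN-separated values: one vector per node, empty for
    // nodes that are not BRANCH_MEMBER.
    std::vector<std::vector<float>> SplitMembershipValues(
        const std::vector<bool>& is_member_node, const std::vector<float>& packed) {
      std::vector<std::vector<float>> by_node;
      by_node.reserve(is_member_node.size());
      std::size_t cur = 0;
      for (bool is_member : is_member_node) {
        by_node.emplace_back();
        if (!is_member) continue;  // non-member nodes get an empty set
        while (cur < packed.size() && !std::isnan(packed[cur])) {
          by_node.back().push_back(packed[cur++]);
        }
        ++cur;  // skip the NaN separator
      }
      return by_node;
    }

    int main() {
      const float nan = std::nanf("");
      // two BRANCH_MEMBER nodes with sets {1, 2} and {5}
      auto groups = SplitMembershipValues({true, false, true}, {1.f, 2.f, nan, 5.f, nan});
      std::cout << groups[0].size() << " " << groups[1].size() << " "
                << groups[2].size() << "\n";  // prints: 2 0 1
      return 0;
    }
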
+ ORT_THROW("TreeEnsemble(ai.onnx.ml==5) is not supported with the minimal build."); +#endif + + aggregate_function = info.GetAttrOrDefault("aggregate_function", 1); + leaf_targetids = info.GetAttrsOrDefault("leaf_targetids"); + n_targets = info.GetAttrOrDefault("n_targets", 0); + nodes_falseleafs = info.GetAttrsOrDefault("nodes_falseleafs"); + nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids"); + nodes_featureids = info.GetAttrsOrDefault("nodes_featureids"); + nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true"); + nodes_trueleafs = info.GetAttrsOrDefault("nodes_trueleafs"); + nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids"); + post_transform = info.GetAttrOrDefault("post_transform", 0); + tree_roots = info.GetAttrsOrDefault("tree_roots"); + } + + void convert_to_v3(TreeEnsembleAttributesV3& output) const { + // Doing all transformations to get the old format. + output.n_targets_or_classes = n_targets; + output.aggregate_function = aggregateFunctionToString(); + output.post_transform = postTransformToString(); + std::vector> membership_values_by_id; + getMembershipValuesById(membership_values_by_id); + transformInputAllTrees(output, membership_values_by_id); + } + + int64_t aggregate_function; + std::vector leaf_targetids; + std::vector leaf_weights; + std::vector membership_values; + int64_t n_targets; + std::vector nodes_falseleafs; + std::vector nodes_falsenodeids; + std::vector nodes_featureids; + std::vector nodes_hitrates; + std::vector nodes_missing_value_tracks_true; + std::vector nodes_modes; + std::vector nodes_splits; + std::vector nodes_trueleafs; + std::vector nodes_truenodeids; + int64_t post_transform; + std::vector tree_roots; + + private: + // `membership_values` are seperated by NAN for different nodes + // It is more convenient to preserve the values for each node in a vector + // The vector would be empty for nodes that are not `BRANCH_MEMBER` + void getMembershipValuesById(std::vector>& membership_values_by_id) const { + membership_values_by_id.clear(); + membership_values_by_id.reserve(nodes_modes.size()); + + size_t curr_id = 0; + for (const auto node_mode : nodes_modes) { + membership_values_by_id.emplace_back(); + if (node_mode != NODE_MODE_ONNX::BRANCH_MEMBER) { + continue; + } + + while (curr_id < membership_values.size() && !_isnan_(membership_values[curr_id])) { + membership_values_by_id.back().push_back(membership_values[curr_id++]); + } + curr_id++; + } + } + + std::string aggregateFunctionToString() const { + switch (aggregate_function) { + case static_cast(AGGREGATE_FUNCTION::AVERAGE): + return "AVERAGE"; + case static_cast(AGGREGATE_FUNCTION::SUM): + return "SUM"; + case static_cast(AGGREGATE_FUNCTION::MIN): + return "MIN"; + case static_cast(AGGREGATE_FUNCTION::MAX): + return "MAX"; + default: + ORT_THROW("Unknown value for aggregate_function."); + } + } + + std::string postTransformToString() const { + switch (post_transform) { + case static_cast(POST_EVAL_TRANSFORM::NONE): + return "NONE"; + case static_cast(POST_EVAL_TRANSFORM::SOFTMAX): + return "SOFTMAX"; + case static_cast(POST_EVAL_TRANSFORM::LOGISTIC): + return "LOGISTIC"; + case static_cast(POST_EVAL_TRANSFORM::SOFTMAX_ZERO): + return "SOFTMAX_ZERO"; + case static_cast(POST_EVAL_TRANSFORM::PROBIT): + return "PROBIT"; + default: + ORT_THROW("Unknown value for post_transform."); + } + } + + int64_t transformInputOneTree( + const size_t curr_id, const int64_t curr_treeid, const int64_t curr_nodeid, const size_t 
curr_membership_value_id, + const bool is_leaf, std::vector>& membership_values_by_id, + TreeEnsembleAttributesV3& output) const { + output.nodes_nodeids.push_back(curr_nodeid); + output.nodes_treeids.push_back(curr_treeid); + + if (is_leaf) { + output.nodes_modes.push_back(NODE_MODE_ONNX::LEAF); + output.target_class_ids.push_back(leaf_targetids[curr_id]); + output.target_class_nodeids.push_back(curr_nodeid); + output.target_class_treeids.push_back(curr_treeid); + output.target_class_weights_as_tensor.push_back(leaf_weights[curr_id]); + + // the below are irrelevant for a `LEAF` + output.nodes_featureids.push_back(0); + output.nodes_truenodeids.push_back(0); + output.nodes_falsenodeids.push_back(0); + output.nodes_values_as_tensor.push_back(0); + if (!nodes_hitrates.empty()) { + output.nodes_hitrates.push_back(0); + } + if (!nodes_missing_value_tracks_true.empty()) { + output.nodes_missing_value_tracks_true.push_back(0); + } + + return curr_nodeid; + } + + output.nodes_featureids.push_back(nodes_featureids[curr_id]); + if (!nodes_hitrates.empty()) { + output.nodes_hitrates_as_tensor.push_back(nodes_hitrates[curr_id]); + } + if (!nodes_missing_value_tracks_true.empty()) { + output.nodes_missing_value_tracks_true.push_back(nodes_missing_value_tracks_true[curr_id]); + } + + // unroll `BRANCH_MEMBER` to a chain of `BRANCH_EQ` + if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER) { + output.nodes_modes.push_back(NODE_MODE_ONNX::BRANCH_EQ); + output.nodes_values_as_tensor.push_back(membership_values_by_id[curr_id][curr_membership_value_id]); + } else { + output.nodes_modes.push_back(nodes_modes[curr_id]); + output.nodes_values_as_tensor.push_back(nodes_splits[curr_id]); + } + + size_t falsenodeid_id = output.nodes_falsenodeids.size(); + output.nodes_falsenodeids.push_back(0); // change after pushing truenode subtree + + int64_t true_nodeid = curr_nodeid + 1; + output.nodes_truenodeids.push_back(true_nodeid); + true_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_truenodeids[curr_id]), + curr_treeid, true_nodeid, 0U, nodes_trueleafs[curr_id] != 0, + membership_values_by_id, output); + + int64_t false_nodeid = true_nodeid + 1; + output.nodes_falsenodeids[falsenodeid_id] = false_nodeid; + + // if node is `BRANCH_MEMBER` we are unrolling the `membership_values` for that node + // therefore if the value is not the last, the `falsenode_id` must be pointing to the "same" node with a different membership value + // so in that case we are only moving the pointer for `membership_values` + // + // otherwise, the `falsenode_id` is pointing to the real falsenode subtree + if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER && + curr_membership_value_id + 1 < membership_values_by_id[curr_id].size()) { + false_nodeid = transformInputOneTree(curr_id, curr_treeid, false_nodeid, curr_membership_value_id + 1, false, + membership_values_by_id, output); + } else { + false_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_falsenodeids[curr_id]), + curr_treeid, false_nodeid, 0U, nodes_falseleafs[curr_id] != 0, + membership_values_by_id, output); + } + return false_nodeid; + } + + void transformInputAllTrees(TreeEnsembleAttributesV3& output, + std::vector>& membership_values_by_id) const { + int64_t curr_treeid = 0; + for (const int64_t& tree_root : tree_roots) { + size_t tree_root_size_t = onnxruntime::narrow(tree_root); + transformInputOneTree(tree_root_size_t, curr_treeid, 0, 0U, + nodes_falsenodeids[tree_root_size_t] == nodes_truenodeids[tree_root_size_t], + membership_values_by_id, 
output); + curr_treeid++; + } + } +}; + +} // namespace detail +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index df27f888bb0af..10d4db0e0e3b0 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -3,15 +3,21 @@ #pragma once -#include "tree_ensemble_aggregator.h" -#include "core/platform/ort_mutex.h" +#include #include "core/platform/threadpool.h" #include "tree_ensemble_helper.h" +#include "tree_ensemble_attribute.h" +#include "tree_ensemble_aggregator.h" namespace onnxruntime { namespace ml { namespace detail { +/** + * These attributes are the kernel attributes. They are different from the onnx operator attributes + * to improve the computation efficiency. The initialization consists in moving the onnx attributes + * into the kernel attributes. + */ class TreeEnsembleCommonAttributes { public: int64_t get_target_or_class_count() const { return this->n_targets_or_classes_; } @@ -57,27 +63,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { Status Init(int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - int64_t n_targets_or_classes, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& target_class_ids, - const std::vector& target_class_nodeids, - const std::vector& target_class_treeids, - const std::vector& target_class_weights, - const std::vector& target_class_weights_as_tensor); + const TreeEnsembleAttributesV3& attributes); protected: TreeNodeElement* ProcessTreeNodeLeave(TreeNodeElement* root, @@ -87,49 +73,52 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const; private: - size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, - const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, - const std::vector& nodes_values_as_tensor, const std::vector& node_values, - const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, - int64_t tree_id, const InlinedVector& node_tree_ids); + bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector& cmodes, + const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, + const InlinedVector& node_tree_ids, InlinedVector> indices); + size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span nodes_missing_value_tracks_true, std::vector& updated_mapping, + 
int64_t tree_id, const InlinedVector& node_tree_ids, gsl::span target_class_weights, + gsl::span target_class_weights_as_tensor, InlinedVector>& indices); }; +// Below is simple implementation of `bit_cast` as it is supported from c++20 and the current supported version is c++17 +// Remove it when that is not the case +template +std::enable_if_t< + sizeof(To) == sizeof(From) && + std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + To> + // constexpr support needs compiler magic + static bit_cast(const From& src) noexcept { + static_assert(std::is_trivially_constructible_v, + "This implementation additionally requires " + "destination type to be trivially constructible"); + + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template +std::conditional_t bit_cast_int(T val) { + if constexpr (sizeof(T) == sizeof(uint32_t)) { + return bit_cast(val); + } else if constexpr (sizeof(T) == sizeof(uint64_t)) { + return bit_cast(val); + } + static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t)); +} + template Status TreeEnsembleCommon::Init(const OpKernelInfo& info) { - std::vector base_values_as_tensor, nodes_hitrates_as_tensor, - nodes_values_as_tensor, target_weights_as_tensor; -#if !defined(ORT_MINIMAL_BUILD) - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_weights_as_tensor)); -#endif - - return Init( - 80, - 128, - 50, - info.GetAttrOrDefault("aggregate_function", "SUM"), - info.GetAttrsOrDefault("base_values"), - base_values_as_tensor, - info.GetAttrOrDefault("n_targets", 0), - info.GetAttrsOrDefault("nodes_falsenodeids"), - info.GetAttrsOrDefault("nodes_featureids"), - info.GetAttrsOrDefault("nodes_hitrates"), - nodes_hitrates_as_tensor, - info.GetAttrsOrDefault("nodes_missing_value_tracks_true"), - info.GetAttrsOrDefault("nodes_modes"), - info.GetAttrsOrDefault("nodes_nodeids"), - info.GetAttrsOrDefault("nodes_treeids"), - info.GetAttrsOrDefault("nodes_truenodeids"), - info.GetAttrsOrDefault("nodes_values"), - nodes_values_as_tensor, - info.GetAttrOrDefault("post_transform", "NONE"), - info.GetAttrsOrDefault("target_ids"), - info.GetAttrsOrDefault("target_nodeids"), - info.GetAttrsOrDefault("target_treeids"), - info.GetAttrsOrDefault("target_weights"), - target_weights_as_tensor); + TreeEnsembleAttributesV3 attributes(info, false); + return Init(80, 128, 50, attributes); } template @@ -137,72 +126,35 @@ Status TreeEnsembleCommon::Init( int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - int64_t n_targets_or_classes, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& target_class_ids, - const std::vector& target_class_nodeids, - 
const std::vector& target_class_treeids, - const std::vector& target_class_weights, - const std::vector& target_class_weights_as_tensor) { + const TreeEnsembleAttributesV3& attributes) { parallel_tree_ = parallel_tree; parallel_tree_N_ = parallel_tree_N; parallel_N_ = parallel_N; - ORT_ENFORCE(n_targets_or_classes > 0); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() || - nodes_falsenodeids.size() == nodes_values_as_tensor.size()); - ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size()); - ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size()); - ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size()); - ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty()); - ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty()); - ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty()); - ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty()); - - aggregate_function_ = MakeAggregateFunction(aggregate_function); - post_transform_ = MakeTransform(post_transform); - if (!base_values_as_tensor.empty()) { - ORT_ENFORCE(base_values.empty()); - base_values_ = base_values_as_tensor; + aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function); + post_transform_ = MakeTransform(attributes.post_transform); + if (!attributes.base_values_as_tensor.empty()) { + ORT_ENFORCE(attributes.base_values.empty()); + base_values_ = attributes.base_values_as_tensor; } else { - base_values_.reserve(base_values.size()); - for (size_t i = 0, limit = base_values.size(); i < limit; ++i) { - base_values_.push_back(static_cast(base_values[i])); + base_values_.reserve(attributes.base_values.size()); + for (size_t i = 0, limit = attributes.base_values.size(); i < limit; ++i) { + base_values_.push_back(static_cast(attributes.base_values[i])); } } - n_targets_or_classes_ = n_targets_or_classes; + n_targets_or_classes_ = attributes.n_targets_or_classes; max_tree_depth_ = 1000; - ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max()); // Additional members size_t limit; uint32_t i; - InlinedVector cmodes; - cmodes.reserve(nodes_modes.size()); + InlinedVector cmodes; + cmodes.reserve(attributes.nodes_modes.size()); same_mode_ = true; int fpos = -1; - for (i = 0, limit = nodes_modes.size(); i < limit; ++i) { - cmodes.push_back(MakeTreeNodeMode(nodes_modes[i])); - if (cmodes[i] == NODE_MODE::LEAF) continue; + for (i = 0, limit = attributes.nodes_modes.size(); i < limit; ++i) { + cmodes.push_back(attributes.nodes_modes[i]); + if (cmodes[i] == NODE_MODE_ONNX::LEAF) continue; if (fpos == -1) { fpos = static_cast(i); continue; @@ -210,7 +162,7 @@ Status TreeEnsembleCommon::Init( if (cmodes[i] != cmodes[fpos]) same_mode_ = false; } - n_nodes_ = nodes_treeids.size(); + n_nodes_ = attributes.nodes_treeids.size(); limit = static_cast(n_nodes_); InlinedVector node_tree_ids; node_tree_ids.reserve(limit); @@ -227,7 +179,7 @@ Status TreeEnsembleCommon::Init( // Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids for (i = 0; i < limit; ++i) { - TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), 
static_cast(nodes_nodeids[i])}; + TreeNodeElementId node_tree_id{static_cast(attributes.nodes_treeids[i]), static_cast(attributes.nodes_nodeids[i])}; auto p = node_tree_ids_map.insert(std::pair(node_tree_id, i)); if (!p.second) { ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there."); @@ -237,13 +189,13 @@ Status TreeEnsembleCommon::Init( TreeNodeElementId coor; for (i = 0; i < limit; ++i) { - if (cmodes[i] == NODE_MODE::LEAF) { + if (cmodes[i] == NODE_MODE_ONNX::LEAF) { truenode_ids.push_back(0); falsenode_ids.push_back(0); } else { TreeNodeElementId& node_tree_id = node_tree_ids[i]; coor.tree_id = node_tree_id.tree_id; - coor.node_id = static_cast(nodes_truenodeids[i]); + coor.node_id = static_cast(attributes.nodes_truenodeids[i]); ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); auto found = node_tree_ids_map.find(coor); @@ -255,7 +207,7 @@ Status TreeEnsembleCommon::Init( } truenode_ids.emplace_back(found->second); - coor.node_id = static_cast(nodes_falsenodeids[i]); + coor.node_id = static_cast(attributes.nodes_falsenodeids[i]); ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); found = node_tree_ids_map.find(coor); if (found == node_tree_ids_map.end()) { @@ -270,41 +222,38 @@ Status TreeEnsembleCommon::Init( } } + // Sort targets + InlinedVector> indices; + indices.reserve(attributes.target_class_nodeids.size()); + for (i = 0, limit = attributes.target_class_nodeids.size(); i < limit; i++) { + indices.emplace_back( + TreeNodeElementId{attributes.target_class_treeids[i], attributes.target_class_nodeids[i]}, i); + } + + std::sort(indices.begin(), indices.end()); + // Let's construct nodes_ such that the false branch is always the next element in nodes_. // updated_mapping will translates the old position of each node to the new node position in nodes_. - std::vector updated_mapping(nodes_treeids.size(), 0); + std::vector updated_mapping(attributes.nodes_treeids.size(), 0); int64_t previous_tree_id = -1; for (i = 0; i < n_nodes_; ++i) { if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) { // New tree. 
int64_t tree_id = node_tree_ids[i].tree_id; size_t root_position = - AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, - nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + AddNodes(i, cmodes, truenode_ids, falsenode_ids, attributes.nodes_featureids, attributes.nodes_values_as_tensor, attributes.nodes_values, + attributes.nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + attributes.target_class_weights, attributes.target_class_weights_as_tensor, indices); roots_.push_back(&nodes_[root_position]); previous_tree_id = tree_id; } } - n_trees_ = roots_.size(); - if (((int64_t)nodes_.size()) != n_nodes_) { - ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ")."); - } - - // Sort targets - InlinedVector> indices; - indices.reserve(target_class_nodeids.size()); - for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { - indices.emplace_back( - std::pair(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i)); - } - - std::sort(indices.begin(), indices.end()); TreeNodeElementId ind; SparseValue w; size_t indi; - for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) { + for (indi = 0, limit = attributes.target_class_nodeids.size(); indi < limit; ++indi) { ind = indices[indi].first; i = indices[indi].second; auto found = node_tree_ids_map.find(ind); @@ -319,9 +268,10 @@ Status TreeEnsembleCommon::Init( // ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf."); continue; } - w.i = target_class_ids[i]; - w.value = target_class_weights_as_tensor.empty() ? static_cast(target_class_weights[i]) - : target_class_weights_as_tensor[i]; + w.i = attributes.target_class_ids[i]; + w.value = attributes.target_class_weights_as_tensor.empty() + ? 
static_cast(attributes.target_class_weights[i]) + : attributes.target_class_weights_as_tensor[i]; if (leaf.truenode_or_weight.weight_data.n_weights == 0) { leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; @@ -331,7 +281,7 @@ Status TreeEnsembleCommon::Init( } has_missing_tracks_ = false; - for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) { + for (auto itm = attributes.nodes_missing_value_tracks_true.begin(); itm != attributes.nodes_missing_value_tracks_true.end(); ++itm) { if (*itm) { has_missing_tracks_ = true; break; @@ -341,13 +291,58 @@ Status TreeEnsembleCommon::Init( return Status::OK(); } +template +bool TreeEnsembleCommon::CheckIfSubtreesAreEqual( + const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector& cmodes, + const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, + const InlinedVector& node_tree_ids, InlinedVector> indices) { + // Leaves have values set at 0 + if (cmodes[left_id] != cmodes[right_id] || nodes_featureids[left_id] != nodes_featureids[right_id] || + (!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id]) || + (nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) { + return false; + } + + if (cmodes[left_id] == NODE_MODE_ONNX::LEAF) { + const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second; + const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second; + + if (target_class_weights_as_tensor.empty()) { + return target_class_weights[left_target_node] == target_class_weights[right_target_node]; + } else { + return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node]; + } + } + + return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, + nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices) && + CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, + nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices); +} + +inline void UpdateThreshold(double val, double& mask) { + uint64_t new_mask = bit_cast(mask) | (1ll << (static_cast(val) - 1)); + mask = bit_cast(new_mask); +} + +inline void UpdateThreshold(float val, float& mask) { + uint32_t new_mask = bit_cast(mask) | (1 << (static_cast(val) - 1)); + mask = bit_cast(new_mask); +} + +#define BITCOUNT(T) int64_t(sizeof(T) * 8) +#define CANMASK(v, T) (v >= 1 && v <= BITCOUNT(T)) && v == std::floor(v) + template size_t TreeEnsembleCommon::AddNodes( - const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, - const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, - const std::vector& nodes_values_as_tensor, const std::vector& node_values, - const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, - const InlinedVector& node_tree_ids) { + const 
size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, + const InlinedVector& node_tree_ids, gsl::span target_class_weights, + gsl::span target_class_weights_as_tensor, InlinedVector>& indices) { // Validate this index maps to the same tree_id as the one we should be building. if (node_tree_ids[i].tree_id != tree_id) { ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i); @@ -364,28 +359,59 @@ size_t TreeEnsembleCommon::AddNodes( updated_mapping[i] = node_pos; TreeNodeElement node; - node.flags = static_cast(cmodes[i]); + node.flags = Convert_NODE_MODE_ONNX_to_ORT(cmodes[i]); node.feature_id = static_cast(nodes_featureids[i]); if (node.feature_id > max_feature_id_) { max_feature_id_ = node.feature_id; } - node.value_or_unique_weight = - nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + + node.value_or_unique_weight = 0; + const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + if (node.flags == NODE_MODE_ORT::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) { + UpdateThreshold(node_threshold, node.value_or_unique_weight); + node.flags = NODE_MODE_ORT::BRANCH_MEMBER; + } else { + node.value_or_unique_weight = node_threshold; + } + if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { - node.flags |= static_cast(MissingTrack::kTrue); + node.flags = static_cast(static_cast(node.flags) | static_cast(MissingTrack::kTrue)); } nodes_.push_back(std::move(node)); if (nodes_[node_pos].is_not_leaf()) { + size_t falsenode_id = falsenode_ids[i]; + + // Categoricals are represented as a chain of `EQ` nodes where the subtree for the true child is identical for all nodes in the chain + // Below we are folding together these nodes into one of mode `BRANCH_MEMBER` + // The threshold of this node should be interpreted as a bitmask showing which categoricals values were found in the chain + // Afterwards, when looking whether a feature is included we can do an `and` with the mask of the node + // and the one of the feature (the mask has only one bit set on the place for its value) + // Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done + if (nodes_[node_pos].flags == NODE_MODE_ORT::BRANCH_MEMBER) { + ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id]; + + while (cmodes[falsenode_id] == NODE_MODE_ONNX::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] && + CANMASK(falsenode_threshold, ThresholdType) && + CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids, + nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) { + UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight); + falsenode_id = falsenode_ids[falsenode_id]; + falsenode_threshold = nodes_values_as_tensor.empty() ? 
static_cast(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id]; + } + } + size_t false_branch = - AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, - node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + target_class_weights, target_class_weights_as_tensor, indices); if (false_branch != node_pos + 1) { ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", static_cast(nodes_[node_pos].flags)); } size_t true_branch = AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, - node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + target_class_weights, target_class_weights_as_tensor, indices); // We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_. // nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch]; nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch]; @@ -684,10 +710,12 @@ void TreeEnsembleCommon::ComputeAgg(concur } \ } -inline bool _isnan_(float x) { return std::isnan(x); } -inline bool _isnan_(double x) { return std::isnan(x); } -inline bool _isnan_(int64_t) { return false; } -inline bool _isnan_(int32_t) { return false; } +// Check whether the feature value is set true in the mask +template +inline bool SetMembershipCheck(T1 val, T2 mask) { + const int64_t val_as_int = static_cast(val); + return CANMASK(val, T2) && (((1ll << (val_as_int - 1)) & bit_cast_int(mask)) != 0); +} template TreeNodeElement* @@ -696,7 +724,7 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( InputType val; if (same_mode_) { switch (root->mode()) { - case NODE_MODE::BRANCH_LEQ: + case NODE_MODE_ORT::BRANCH_LEQ: if (has_missing_tracks_) { while (root->is_not_leaf()) { val = x_data[root->feature_id]; @@ -711,22 +739,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } break; - case NODE_MODE::BRANCH_LT: + case NODE_MODE_ORT::BRANCH_LT: TREE_FIND_VALUE(<) break; - case NODE_MODE::BRANCH_GTE: + case NODE_MODE_ORT::BRANCH_GTE: TREE_FIND_VALUE(>=) break; - case NODE_MODE::BRANCH_GT: + case NODE_MODE_ORT::BRANCH_GT: TREE_FIND_VALUE(>) break; - case NODE_MODE::BRANCH_EQ: + case NODE_MODE_ORT::BRANCH_EQ: TREE_FIND_VALUE(==) break; - case NODE_MODE::BRANCH_NEQ: + case NODE_MODE_ORT::BRANCH_NEQ: TREE_FIND_VALUE(!=) break; - case NODE_MODE::LEAF: + case NODE_MODE_ORT::BRANCH_MEMBER: + if (has_missing_tracks_) { + while (root->is_not_leaf()) { + val = x_data[root->feature_id]; + root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; + } + } else { + while (root->is_not_leaf()) { + val = x_data[root->feature_id]; + root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1; + } + } + case NODE_MODE_ORT::LEAF: break; } } else { // Different rules to compare to node thresholds. 
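
Note on the `BRANCH_MEMBER` folding introduced above: a chain of `BRANCH_EQ` nodes that test the same feature against small integral category values, and whose true-subtrees are identical (verified by `CheckIfSubtreesAreEqual`), is collapsed into one node whose threshold slot is reused as a bitmask. `UpdateThreshold` sets bit `(v - 1)` for every folded category `v`, and `ProcessTreeNodeLeave` later answers membership with a single AND through `SetMembershipCheck`. Values outside `[1, BITCOUNT(ThresholdType)]` or non-integral values fail `CANMASK`, so those nodes stay as plain `BRANCH_EQ`. The snippet below is a minimal, self-contained sketch of that encoding for a 32-bit threshold type; the helper names are illustrative and are not the kernel's actual functions, which store the mask in `value_or_unique_weight` via the `bit_cast` helpers defined earlier in this header.

```cpp
// Minimal sketch (not the kernel code): folding small integer categories into a
// bitmask that is stored in the bits of a float "threshold".
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Reinterpret the bits of a float as uint32_t and back (stand-in for the
// bit_cast helper in the diff; std::bit_cast would do the same in C++20).
static uint32_t as_bits(float f) { uint32_t u; std::memcpy(&u, &f, sizeof(u)); return u; }
static float from_bits(uint32_t u) { float f; std::memcpy(&f, &u, sizeof(f)); return f; }

// A category v can be folded only if 1 <= v <= 32 and v is integral,
// mirroring the CANMASK condition for a 32-bit threshold type.
static bool can_mask(float v) { return v >= 1 && v <= 32 && v == std::floor(v); }

// Fold one category into the mask: set bit (v - 1), as UpdateThreshold does.
static void add_category(float& mask, float v) {
  assert(can_mask(v));
  mask = from_bits(as_bits(mask) | (1u << (static_cast<uint32_t>(v) - 1)));
}

// Membership test: one AND against the feature's bit, as SetMembershipCheck does.
static bool is_member(float feature, float mask) {
  return can_mask(feature) &&
         ((1u << (static_cast<uint32_t>(feature) - 1)) & as_bits(mask)) != 0;
}

int main() {
  float mask = from_bits(0u);       // empty set
  add_category(mask, 3.0f);         // fold the node "x == 3"
  add_category(mask, 7.0f);         // fold the node "x == 7"
  assert(is_member(3.0f, mask));
  assert(is_member(7.0f, mask));
  assert(!is_member(5.0f, mask));   // not part of the folded chain
  assert(!is_member(40.0f, mask));  // too large to be representable in the mask
  return 0;
}
```

A folded set then costs one comparison per lookup instead of one node per category, which is the point of collapsing the chain.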
@@ -735,31 +777,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( val = x_data[root->feature_id]; threshold = root->value_or_unique_weight; switch (root->mode()) { - case NODE_MODE::BRANCH_LEQ: + case NODE_MODE_ORT::BRANCH_LEQ: root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_LT: + case NODE_MODE_ORT::BRANCH_LT: root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_GTE: + case NODE_MODE_ORT::BRANCH_GTE: root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_GT: + case NODE_MODE_ORT::BRANCH_GT: root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_EQ: + case NODE_MODE_ORT::BRANCH_EQ: root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_NEQ: + case NODE_MODE_ORT::BRANCH_NEQ: root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::LEAF: + case NODE_MODE_ORT::BRANCH_MEMBER: + root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; + break; + case NODE_MODE_ORT::LEAF: return root; } } @@ -786,67 +833,13 @@ class TreeEnsembleCommonClassifier : public TreeEnsembleCommon& base_values, - const std::vector& base_values_as_tensor, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& class_ids, - const std::vector& class_nodeids, - const std::vector& class_treeids, - const std::vector& class_weights, - const std::vector& class_weights_as_tensor, - const std::vector& classlabels_strings, - const std::vector& classlabels_int64s); + const TreeEnsembleAttributesV3& attributes); }; template Status TreeEnsembleCommonClassifier::Init(const OpKernelInfo& info) { - std::vector base_values_as_tensor, nodes_hitrates_as_tensor, - nodes_values_as_tensor, class_weights_as_tensor; -#if !defined(ORT_MINIMAL_BUILD) - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", class_weights_as_tensor)); -#endif - - return Init( - 80, - 128, - 50, - info.GetAttrOrDefault("aggregate_function", "SUM"), - info.GetAttrsOrDefault("base_values"), - base_values_as_tensor, - info.GetAttrsOrDefault("nodes_falsenodeids"), - info.GetAttrsOrDefault("nodes_featureids"), - info.GetAttrsOrDefault("nodes_hitrates"), - nodes_hitrates_as_tensor, - info.GetAttrsOrDefault("nodes_missing_value_tracks_true"), - 
info.GetAttrsOrDefault("nodes_modes"), - info.GetAttrsOrDefault("nodes_nodeids"), - info.GetAttrsOrDefault("nodes_treeids"), - info.GetAttrsOrDefault("nodes_truenodeids"), - info.GetAttrsOrDefault("nodes_values"), - nodes_values_as_tensor, - info.GetAttrOrDefault("post_transform", "NONE"), - info.GetAttrsOrDefault("class_ids"), - info.GetAttrsOrDefault("class_nodeids"), - info.GetAttrsOrDefault("class_treeids"), - info.GetAttrsOrDefault("class_weights"), - class_weights_as_tensor, - info.GetAttrsOrDefault("classlabels_strings"), - info.GetAttrsOrDefault("classlabels_int64s")); + TreeEnsembleAttributesV3 attributes(info, true); + return Init(80, 128, 50, attributes); } template @@ -854,65 +847,20 @@ Status TreeEnsembleCommonClassifier::Init( int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& class_ids, - const std::vector& class_nodeids, - const std::vector& class_treeids, - const std::vector& class_weights, - const std::vector& class_weights_as_tensor, - const std::vector& classlabels_strings, - const std::vector& classlabels_int64s) { - auto status = TreeEnsembleCommon::Init( - parallel_tree, - parallel_tree_N, - parallel_N, - aggregate_function, - base_values, - base_values_as_tensor, - classlabels_strings.empty() ? classlabels_int64s.size() - : classlabels_strings.size(), - nodes_falsenodeids, - nodes_featureids, - nodes_hitrates, - nodes_hitrates_as_tensor, - nodes_missing_value_tracks_true, - nodes_modes, - nodes_nodeids, - nodes_treeids, - nodes_truenodeids, - nodes_values, - nodes_values_as_tensor, - post_transform, - class_ids, - class_nodeids, - class_treeids, - class_weights, - class_weights_as_tensor); + const TreeEnsembleAttributesV3& attributes) { + auto status = TreeEnsembleCommon::Init(parallel_tree, parallel_tree_N, parallel_N, attributes); ORT_RETURN_IF_ERROR(status); - classlabels_strings_ = classlabels_strings; - classlabels_int64s_ = classlabels_int64s; + classlabels_strings_ = attributes.classlabels_strings; + classlabels_int64s_ = attributes.classlabels_int64s; InlinedHashSet weights_classes; - weights_classes.reserve(class_ids.size()); + weights_classes.reserve(attributes.target_class_ids.size()); weights_are_all_positive_ = true; - for (size_t i = 0, end = class_ids.size(); i < end; ++i) { - weights_classes.insert(class_ids[i]); - if (weights_are_all_positive_ && (!class_weights.empty() ? class_weights[i] : class_weights_as_tensor[i]) < 0) + for (size_t i = 0, end = attributes.target_class_ids.size(); i < end; ++i) { + weights_classes.insert(attributes.target_class_ids[i]); + if (weights_are_all_positive_ && (!attributes.target_class_weights.empty() ? 
attributes.target_class_weights[i] + : attributes.target_class_weights_as_tensor[i]) < 0) weights_are_all_positive_ = false; } binary_case_ = this->n_targets_or_classes_ == 2 && weights_classes.size() == 1; @@ -957,6 +905,43 @@ Status TreeEnsembleCommonClassifier::compu return Status::OK(); } +template +class TreeEnsembleCommonV5 : public TreeEnsembleCommon { + public: + virtual Status Init(const OpKernelInfo& info); + + Status Init(int parallel_tree, + int parallel_tree_N, + int parallel_N, + const TreeEnsembleAttributesV5& attributes); +}; + +template +Status TreeEnsembleCommonV5::Init(const OpKernelInfo& info) { + TreeEnsembleAttributesV5 attributes(info); + return Init(80, 128, 50, attributes); +} + +template +Status TreeEnsembleCommonV5::Init( + int parallel_tree, + int parallel_tree_N, + int parallel_N, + const TreeEnsembleAttributesV5& attributes) { + TreeEnsembleAttributesV3 attributes_v3; + attributes.convert_to_v3(attributes_v3); + + attributes_v3.base_values.clear(); + attributes_v3.base_values_as_tensor.clear(); + attributes_v3.nodes_hitrates.clear(); + attributes_v3.nodes_values.clear(); + attributes_v3.target_class_weights.clear(); + + auto status = TreeEnsembleCommon::Init(parallel_tree, parallel_tree_N, parallel_N, attributes_v3); + ORT_RETURN_IF_ERROR(status); + return Status::OK(); +} + } // namespace detail } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc index e2981da3a6f25..399dfd56b93c6 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc @@ -5,63 +5,53 @@ #include "core/providers/cpu/ml/tree_ensemble_helper.h" #include "core/common/common.h" +#include "core/common/safeint.h" #include "onnx/defs/tensor_proto_util.h" +#include "core/framework/tensorprotoutils.h" using namespace ::onnxruntime::common; using namespace std; namespace onnxruntime { namespace ml { -Status GetNumberOfElementsAttrsOrDefault(const OpKernelInfo& info, const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - size_t& n_elements, ONNX_NAMESPACE::TensorProto& proto) { - auto status = info.GetAttr(name, &proto); - if (!status.IsOK()) { - // Attribute is missing, n_elements is set to 0. 
- n_elements = 0; - return Status::OK(); - } - auto n_dims = proto.dims_size(); - if (n_dims == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attribute:'", name, "' is specified but is empty."); - } - ORT_ENFORCE(n_dims == 1, "Attribute '", name, "' must be a vector."); - ORT_ENFORCE(proto.data_type() == proto_type, - "Unexpected type (", proto.data_type(), "(for attribute '", name, "'."); - - n_elements = onnxruntime::narrow(proto.dims()[0]); - ORT_ENFORCE(n_elements > 0, "Attribute '", name, "' has one dimension but is empty."); - return Status::OK(); -} +template +Status GetAnyVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + ONNX_NAMESPACE::TensorProto proto; + auto result = info.GetAttr(name, &proto); -template -Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, std::vector& data) { - if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE) { - ORT_ENFORCE((std::is_same::value)); - } else if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) { - ORT_ENFORCE((std::is_same::value)); - } else { - ORT_NOT_IMPLEMENTED("GetVectorAttrsOrDefault not implemented for type ", proto_type); + SafeInt n_elements(1); + for (auto dim : proto.dims()) { + n_elements *= dim; } - ONNX_NAMESPACE::TensorProto proto; - size_t n_elements; - data.clear(); - ORT_THROW_IF_ERROR(GetNumberOfElementsAttrsOrDefault(info, name, proto_type, n_elements, proto)); - if (n_elements == 0) { + if (proto.dims().empty()) { return Status::OK(); } - data = ONNX_NAMESPACE::ParseData(&proto); + + const SafeInt tensor_size(n_elements); + data.clear(); + data.resize(tensor_size); + + result = utils::UnpackTensor(proto, std::filesystem::path(), data.data(), tensor_size); + ORT_ENFORCE(result.IsOK(), "TreeEnsemble could not unpack tensor attribute ", name); + return Status::OK(); } Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { - return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE, data); + return GetAnyVectorAttrsOrDefault(info, name, data); } Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { - return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, data); + return GetAnyVectorAttrsOrDefault(info, name, data); +} + +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + return GetAnyVectorAttrsOrDefault(info, name, data); +} + +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + return GetAnyVectorAttrsOrDefault(info, name, data); } } // namespace ml diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h index 33172c343a88e..ba23f1ad28ec1 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h @@ -13,6 +13,8 @@ namespace ml { Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); +Status GetVectorAttrsOrDefault(const OpKernelInfo& 
info, const std::string& name, std::vector& data); } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 23630dcb63efa..24a5dcab225c4 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -5,6 +5,7 @@ #include "core/common/safeint.h" #include "core/framework/tensor.h" +#include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" #include "core/providers/common.h" #include "core/util/force_inline.h" @@ -12,90 +13,178 @@ namespace onnxruntime { -// Utility to convert from MLFloat16 to float only when the input type is MLFloat16. -template -ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val); +namespace { -template <> -ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { - return val.ToFloat(); -} +template || std::is_same_v, void>> +void ComputeJob( + const T* X_data, + const T* scale_data, + const T* bias_data, + const ptrdiff_t task_idx, + const int64_t norm_size, + const float* scale_float_ptr, + const float* bias_float_ptr, + float epsilon, + bool simplified, + T* Y_data, + U* mean_data, + U* inv_std_dev_data, + AllocatorPtr alloc) { + ORT_UNUSED_PARAMETER(scale_float_ptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(alloc); + + const T* p_input = X_data + task_idx * norm_size; + T* p_output = Y_data + task_idx * norm_size; + + T mean(0.0f); + T mean_square(0.0f); + + for (int64_t h = 0; h < norm_size; h++) { + p_output[h] = p_input[h]; + mean += p_input[h]; + mean_square += p_input[h] * p_input[h]; + } -template <> -ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { - return double(ConvertMLFloat16ToDoubleOrFloatIfNeeded(val)); -} + mean = mean / norm_size; + if (simplified) { + mean_square = sqrt(mean_square / norm_size + epsilon); + } else { + mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); + } -template <> -ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded(float val) { - return val; -} + for (int64_t h = 0; h < norm_size; h++) { + if (simplified) { + p_output[h] = p_output[h] / mean_square * scale_data[h]; + } else if (nullptr == bias_data) { + p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h]; + } else { + p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h]; + } + } -template <> -ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded(double val) { - return val; -} + if (mean_data != nullptr) { + // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow + mean_data[task_idx] = gsl::narrow_cast(mean); + } -ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(float val) { - return val; + if (inv_std_dev_data != nullptr) { + inv_std_dev_data[task_idx] = gsl::narrow_cast(1 / mean_square); + } } -ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(double val) { - // ONNX spec doesn't support 'double' for 'Ret' so when 'T' == double, 'Ret' == float and we need to narrow - return gsl::narrow_cast(val); -} +template +void ComputeJob( + const MLFloat16* X_data, + const MLFloat16* scale_data, + const MLFloat16* bias_data, + const ptrdiff_t task_idx, + const int64_t norm_size, + const float* scale_float_ptr, + const float* bias_float_ptr, + float epsilon, + bool simplified, + MLFloat16* 
Y_data, + U* mean_data, + U* inv_std_dev_data, + AllocatorPtr alloc) { + ORT_UNUSED_PARAMETER(scale_data); // only used in float/double overload + ORT_UNUSED_PARAMETER(bias_data); // only used in float/double overload + + const MLFloat16* p_input = X_data + task_idx * norm_size; + MLFloat16* p_output = Y_data + task_idx * norm_size; + + float mean(0.0f); + float mean_square(0.0f); + + const size_t num_elems = static_cast(norm_size); + IAllocatorUniquePtr input_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems); + + IAllocatorUniquePtr output_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); + float* output_float_ptr = output_float_uptr.get(); + + const float* input_float_ptr = input_float_uptr.get(); + for (size_t h = 0; h < num_elems; h++) { + output_float_ptr[h] = input_float_ptr[h]; + mean += input_float_ptr[h]; + mean_square += input_float_ptr[h] * input_float_ptr[h]; + } -// Function template that only converts the input value to MLFloat16 if T is MLFloat16. -template -ORT_FORCEINLINE constexpr typename std::enable_if_t || std::is_same_v, float> -ConvertToMLFloat16IfNeeded(float val) { - return val; -} + mean = mean / norm_size; + if (simplified) { + mean_square = sqrt(mean_square / norm_size + epsilon); + } else { + mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); + } + + for (size_t h = 0; h < num_elems; h++) { + if (simplified) { + output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h]; + } else if (nullptr == bias_float_ptr) { + output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h]; + } else { + output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h]; + } + } + + MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems); -template -ORT_FORCEINLINE constexpr typename std::enable_if_t, MLFloat16> -ConvertToMLFloat16IfNeeded(float val) { - return MLFloat16(val); + if (mean_data != nullptr) { + // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow + mean_data[task_idx] = MLFloat16(mean); + } + + if (inv_std_dev_data != nullptr) { + inv_std_dev_data[task_idx] = MLFloat16(1 / mean_square); + } } -template -ORT_FORCEINLINE constexpr double ConvertToMLFloat16IfNeeded(double val) { - return val; +void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr& dest, bool& is_packed) { + if (tensor.GetElementType() == utils::ToTensorProtoElementType()) { + auto tensor_data_ptr = tensor.Data(); + auto tensor_size = static_cast(tensor.Shape().Size()); + auto float_ptr = IAllocator::MakeUniquePtr(alloc, tensor_size, true); + + MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size); + dest = std::move(float_ptr); + is_packed = true; + } } +} // namespace + LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op) - : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op} { + : OpKernel(op_kernel_info), + simplified_{simplified}, + contrib_op_{contrib_op}, + prepacked_scale_fp32_data_(nullptr), + prepacked_scale_fp32_size_(0), + prepacked_bias_fp32_data_(nullptr), + prepacked_bias_fp32_size_(0) { ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); } -namespace { template -Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float 
epsilon, bool simplified) { +Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const { // Inputs const Tensor* X = p_ctx->Input(0); - const Tensor* scale = p_ctx->Input(1); - const Tensor* bias = p_ctx->Input(2); + const Tensor* scale = prepacked_scale_fp32_data_ ? nullptr : p_ctx->Input(1); + const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(2); const T* X_data = X->Data(); - const T* scale_data = scale->Data(); + const T* scale_data = scale ? scale->Data() : nullptr; const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data(); const TensorShape& x_shape = X->Shape(); - const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions()); - int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow(axis)); - int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); - - const auto scale_size = scale->Shape().Size(); - const auto bias_size = (bias_data) ? bias->Shape().Size() : 0; - if (scale_size != norm_size || (bias_data && bias_size != norm_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Size of X.shape()[axis:] == ", norm_size, - ". Size of scale and bias (if provided) must match this. Got scale size of ", - scale_size, " and bias size of ", bias_size); - } - + size_t scale_size = scale ? static_cast(scale->Shape().Size()) : prepacked_scale_fp32_size_; + size_t bias_size = bias ? static_cast(bias->Shape().Size()) : prepacked_bias_fp32_size_; Tensor* Y = p_ctx->Output(0, x_shape); - auto Y_data = Y->MutableData(); + T* Y_data = Y->MutableData(); + + const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions()); std::vector mean_inv_std_dev_dim; mean_inv_std_dev_dim.reserve(x_shape.NumDimensions()); @@ -107,11 +196,7 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo } } - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); - int output_index = 1; - U* mean_data = nullptr; if (!simplified) { Tensor* mean = p_ctx->Output(output_index++, TensorShape(mean_inv_std_dev_dim)); @@ -126,87 +211,91 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo inv_std_dev_data = inv_std_dev->MutableData(); } - concurrency::ThreadPool::TryBatchParallelFor( - p_ctx->GetOperatorThreadPool(), static_cast(norm_count), - [&](ptrdiff_t task_idx) { - const T* p_input = X_data + task_idx * norm_size; - T* p_output = Y_data + task_idx * norm_size; - - using DoubleOrFloat = typename std::conditional< - std::is_same::value, // If T is double - double, // Use double - float // Otherwise, use float (covers float and MLFloat16) - >::type; - - DoubleOrFloat mean(0.0f); - DoubleOrFloat mean_square(0.0f); - - for (int64_t h = 0; h < norm_size; h++) { - DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_input[h]); - mean += input_value; - mean_square += input_value * input_value; - } - - mean = mean / norm_size; - if (simplified) { - mean_square = sqrt(mean_square / norm_size + epsilon); - } else { - mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); - } - - for (int64_t h = 0; h < norm_size; h++) { - DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_input[h]); - DoubleOrFloat scale_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(scale_data[h]); - if (simplified) { - p_output[h] = ConvertToMLFloat16IfNeeded(input_value / mean_square * scale_value); - } else if (nullptr == bias) { - p_output[h] = 
ConvertToMLFloat16IfNeeded((input_value - mean) / mean_square * scale_value); - } else { - DoubleOrFloat bias_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(bias_data[h]); - p_output[h] = ConvertToMLFloat16IfNeeded((input_value - mean) / mean_square * scale_value + bias_value); - } - } - - if (mean_data != nullptr) { - // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow - mean_data[task_idx] = ConvertToMLFloat16IfNeeded(ConvertToFloatIfNeeded(mean)); - } - - if (inv_std_dev_data != nullptr) { - inv_std_dev_data[task_idx] = ConvertToMLFloat16IfNeeded(ConvertToFloatIfNeeded(1 / mean_square)); - } - }, - 0); + onnxruntime::concurrency::ThreadPool* thread_pool = p_ctx->GetOperatorThreadPool(); - return Status::OK(); + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + return ComputeWithoutContext(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data, + inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc); } -template -struct SrcDispatcher { - Status operator()(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified, bool contrib_op) const { - // the contrib op kernel was always registered with the same type for all constraints. - // our implementation of the onnx op only supports 'float' as the U constraint. -#if !defined(DISABLE_CONTRIB_OPS) - if (contrib_op) { - return ComputeImpl(p_ctx, orig_axis, epsilon, simplified); - } else -#else - ORT_UNUSED_PARAMETER(contrib_op); -#endif - { - return ComputeImpl(p_ctx, orig_axis, epsilon, simplified); - } - } -}; -} // namespace - Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const { const auto elem_type = p_ctx->Input(0)->GetElementType(); using SupportedTypeList = boost::mp11::mp_list; utils::MLTypeCallDispatcherFromTypeList t_disp(elem_type); - return t_disp.InvokeRet(p_ctx, axis_, epsilon_, simplified_, contrib_op_); + return t_disp.InvokeRet(this, p_ctx, axis_, epsilon_, simplified_, contrib_op_); +} + +Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) { + ORT_UNUSED_PARAMETER(prepacked_weights); + + is_packed = false; + if (input_idx == 1) { // scale + prepacked_scale_fp32_size_ = static_cast(tensor.Shape().Size()); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed); + } else if (input_idx == 2) { // bias + prepacked_bias_fp32_size_ = static_cast(tensor.Shape().Size()); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); + } + + return Status::OK(); +} + +template +Status LayerNormImpl::ComputeWithoutContext( + const T* X_data, + const TensorShape& x_shape, + const T* scale_data, + size_t scale_size, + const T* bias_data, + size_t bias_size, + T* Y_data, + U* mean_data, + U* inv_std_dev_data, + onnxruntime::concurrency::ThreadPool* thread_pool, + int64_t axis, + float epsilon, + bool simplified, + AllocatorPtr alloc) const { + int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow(axis)); + int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); + + if (static_cast(scale_size) != norm_size || (bias_data && static_cast(bias_size) != norm_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", norm_size, + ". Size of scale and bias (if provided) must match this. 
Got scale size of ", + scale_size, " and bias size of ", bias_size); + } + + IAllocatorUniquePtr scale_fp32; + IAllocatorUniquePtr bias_fp32; + if constexpr (std::is_same_v) { + if (prepacked_scale_fp32_data_ == nullptr) { + const size_t num_elems = static_cast(norm_size); + scale_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems); + } + if (prepacked_bias_fp32_data_ == nullptr && bias_data) { + const size_t num_elems = static_cast(norm_size); + bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + } + } + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, static_cast(norm_count), + [&](ptrdiff_t task_idx) { + ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, + prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(), + prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(), + epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc); + }, + 0); + + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h index 393c637dbda18..f8b528b398cba 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h @@ -4,6 +4,7 @@ #pragma once #include "core/common/common.h" +#include "core/framework/allocator.h" #include "core/framework/op_kernel.h" #include "core/framework/tensor.h" @@ -14,11 +15,58 @@ class LayerNormImpl : public OpKernel { LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false); Status Compute(OpKernelContext* p_op_kernel_context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + + // This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`. + template + Status ComputeWithoutContext( + const T* X_data, + const TensorShape& x_shape, + const T* scale_data, + size_t scale_size, + const T* bias_data, + size_t bias_size, + T* Y_data, + U* mean_data, + U* inv_std_dev, + onnxruntime::concurrency::ThreadPool* thread_pool, + int64_t axis, + float epsilon, + bool simplified, + AllocatorPtr alloc) const; + private: + template + Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const; + + template + struct SrcDispatcher { + Status operator()(const LayerNormImpl* p_instance, OpKernelContext* p_ctx, int64_t orig_axis, + float epsilon, bool simplified, bool contrib_op) const { + // the contrib op kernel was always registered with the same type for all constraints. + // our implementation of the onnx op only supports 'float' as the U constraint. 
+#if !defined(DISABLE_CONTRIB_OPS) + if (contrib_op) { + return p_instance->ComputeImpl(p_ctx, orig_axis, epsilon, simplified); + } else +#else + ORT_UNUSED_PARAMETER(contrib_op); +#endif + { + return p_instance->ComputeImpl(p_ctx, orig_axis, epsilon, simplified); + } + } + }; + int64_t axis_; float epsilon_; const bool simplified_; const bool contrib_op_; + IAllocatorUniquePtr prepacked_scale_fp32_data_; + size_t prepacked_scale_fp32_size_; + IAllocatorUniquePtr prepacked_bias_fp32_data_; + size_t prepacked_bias_fp32_size_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc index cb162ade44559..be448455194f6 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc @@ -14,10 +14,11 @@ namespace onnxruntime { // uint8_t kernel supports weight being either uint8_t or int8_t -ONNX_OPERATOR_TYPED_KERNEL_EX( +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( QLinearMatMul, kOnnxDomain, 10, + 20, uint8_t, kCpuExecutionProvider, KernelDefBuilder() @@ -26,21 +27,45 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), QLinearMatMul); +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 21, + uint8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TS", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); + // int8_t kernel only supports weight being int8_t -#define REGISTER_QLINEARMATMUL_INT8_KERNEL() \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - QLinearMatMul, \ - kOnnxDomain, \ - 10, \ - int8_t, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T3", DataTypeImpl::GetTensorType()), \ - QLinearMatMul); - -REGISTER_QLINEARMATMUL_INT8_KERNEL(); +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 10, + 20, + int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 21, + int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TS", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); Status QLinearMatMul::Compute(OpKernelContext* ctx) const { const auto* a = ctx->Input(IN_A); diff --git a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h index 4b31e3a82f2d0..6960f8838ffde 100644 --- a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h +++ b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h @@ -20,15 +20,6 @@ class UnsqueezeBase { }; Status PrepareCompute(OpKernelContext* context, Prepare& p) const; - - protected: - UnsqueezeBase(const OpKernelInfo& info) { - size_t num_inputs = info.GetInputCount(); - if (num_inputs == 1) { // axes must be a valid attribute - ORT_ENFORCE(info.GetAttrs("axes", axes_).IsOK(), 
"Missing/Invalid 'axes' attribute value"); - } - } - static TensorShapeVector ComputeOutputShape( const TensorShape& input_shape, const TensorShapeVector& axes) { @@ -59,6 +50,14 @@ class UnsqueezeBase { return output_shape; } + protected: + UnsqueezeBase(const OpKernelInfo& info) { + size_t num_inputs = info.GetInputCount(); + if (num_inputs == 1) { // axes must be a valid attribute + ORT_ENFORCE(info.GetAttrs("axes", axes_).IsOK(), "Missing/Invalid 'axes' attribute value"); + } + } + TensorShapeVector axes_; }; diff --git a/onnxruntime/core/providers/cpu/text/string_normalizer.cc b/onnxruntime/core/providers/cpu/text/string_normalizer.cc index 32de3105d627d..9bc671f68f19a 100644 --- a/onnxruntime/core/providers/cpu/text/string_normalizer.cc +++ b/onnxruntime/core/providers/cpu/text/string_normalizer.cc @@ -8,6 +8,7 @@ #include "onnxruntime_config.h" #ifdef _MSC_VER +#include #include #endif // _MSC_VER diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.cc b/onnxruntime/core/providers/cuda/cuda_allocator.cc index 2189af8e0ee2d..8c96d8f57a0ba 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.cc +++ b/onnxruntime/core/providers/cuda/cuda_allocator.cc @@ -69,7 +69,7 @@ void* CUDAExternalAllocator::Alloc(size_t size) { void CUDAExternalAllocator::Free(void* p) { free_(p); - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); auto it = reserved_.find(p); if (it != reserved_.end()) { reserved_.erase(it); @@ -80,7 +80,7 @@ void CUDAExternalAllocator::Free(void* p) { void* CUDAExternalAllocator::Reserve(size_t size) { void* p = Alloc(size); if (!p) return nullptr; - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); ORT_ENFORCE(reserved_.find(p) == reserved_.end()); reserved_.insert(p); return p; diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.h b/onnxruntime/core/providers/cuda/cuda_allocator.h index 86d0d8007bbd8..2d94e2b1cda89 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.h +++ b/onnxruntime/core/providers/cuda/cuda_allocator.h @@ -5,7 +5,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -42,7 +42,7 @@ class CUDAExternalAllocator : public CUDAAllocator { void* Reserve(size_t size) override; private: - mutable OrtMutex lock_; + mutable std::mutex lock_; ExternalAlloc alloc_; ExternalFree free_; ExternalEmptyCache empty_cache_; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 82b29c7b0562e..d4013a7dc3d57 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -50,7 +50,6 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - // do we support async copy? // The cudaMemCpyAsync will handle the pinned memory and non-pinned memory, // so we don't need the check here. 
auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); @@ -324,7 +323,7 @@ DataLayout CUDAExecutionProvider::GetPreferredLayout() const { CUDAExecutionProvider::~CUDAExecutionProvider() { // clean up thread local context caches { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { const auto cache = cache_weak.lock(); if (!cache) continue; @@ -369,7 +368,7 @@ CUDAExecutionProvider::PerThreadContext& CUDAExecutionProvider::GetPerThreadCont // get context and update cache std::shared_ptr context; { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); // get or create a context if (context_state_.retired_context_pool.empty()) { @@ -406,7 +405,7 @@ void CUDAExecutionProvider::ReleasePerThreadContext() const { ORT_ENFORCE(cached_context); { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); context_state_.active_contexts.erase(cached_context); context_state_.retired_context_pool.push_back(cached_context); } @@ -964,6 +963,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin); + // OpSet 13 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add); @@ -1200,6 +1206,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin); + // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, 14, float, Relu); @@ -1641,6 +1654,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1823,9 +1839,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { 19, IsInf)>, // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1917,6 +1930,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 13 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2151,6 +2171,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 14 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2567,6 +2594,32 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { return false; } +static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) { + // Opset 12 introduced the attribute "select_last_index" + if (node.SinceVersion() >= 12) { + const auto& node_attributes = node.GetAttributes(); + + for (auto& attr : node_attributes) { + auto& attr_name = attr.first; + auto& attr_value = attr.second; + + // CuDNN doesn't support picking the last index in case of encountering + // duplicate max values. + // CuDNN's API doc doesn't mention what happens in case duplicates are encountered, + // but based on testing, the results seem to indicate a "stable" implementation + // (i.e.) relative ordering is preserved which is the expected behavior when the + // attribute takes on the default value (most common use-case for this operator). 
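The select_last_index distinction that drives this CPU fallback is easy to state in plain C++. The sketch below is a standalone illustration, not ONNX Runtime or cuDNN code: the cuDNN reduction path only ever yields the first occurrence of the extreme value, so a node that asks for the last occurrence has to run on the CPU kernel instead.

#include <cstddef>
#include <vector>

// ArgMax keeping the first duplicate (select_last_index == 0, the default,
// which is the behavior the CUDA/cuDNN path provides).
static size_t ArgMaxFirst(const std::vector<float>& v) {
  size_t best = 0;
  for (size_t i = 1; i < v.size(); ++i)
    if (v[i] > v[best]) best = i;   // strict '>' keeps the earliest maximum
  return best;
}

// ArgMax keeping the last duplicate (select_last_index == 1, CPU-only here).
static size_t ArgMaxLast(const std::vector<float>& v) {
  size_t best = 0;
  for (size_t i = 1; i < v.size(); ++i)
    if (v[i] >= v[best]) best = i;  // '>=' moves to the latest maximum
  return best;
}

// For {3, 7, 7, 1}: ArgMaxFirst returns 1, ArgMaxLast returns 2.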
+ if ("select_last_index" == attr_name) { + if (attr_value.i() != 0) { + return true; + } + } + } + } + + return false; +} + std::unique_ptr CUDAExecutionProvider::GetDataTransfer() const { return std::make_unique(); } @@ -2616,6 +2669,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, } else if ("ConvTranspose" == node.OpType()) { not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred()); force_inside = !not_supported; + } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) { + not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node); + force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside @@ -2637,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index c5736733beb1d..bd2be2eac2181 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -9,7 +9,7 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cuda/cuda_execution_provider_info.h" #include "core/providers/cuda/cuda_graph.h" #include "core/providers/cuda/cuda_pch.h" @@ -251,7 +251,7 @@ class CUDAExecutionProvider : public IExecutionProvider { std::set, std::owner_less>> caches_to_update_on_destruction; // synchronizes access to PerThreadContextState members - OrtMutex mutex; + std::mutex mutex; }; // The execution provider maintains the PerThreadContexts in this structure. 
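One note for readers of the mutex hunk just above: the diff rendering dropped the angle brackets, so the container and lock declarations look bare. A minimal sketch of the per-thread-context bookkeeping pattern, using hypothetical stand-in names rather than the provider's real types, is:

#include <memory>
#include <mutex>
#include <set>

struct Context {};  // stand-in for the provider's per-thread context type

struct ContextState {
  std::set<std::shared_ptr<Context>> active_contexts;
  // weak_ptrs ordered by owner, so entries stay comparable even after they expire
  std::set<std::weak_ptr<Context>, std::owner_less<std::weak_ptr<Context>>> caches_to_update_on_destruction;
  std::mutex mutex;  // was OrtMutex; now the standard library type
};

void CleanUpOnDestruction(ContextState& state) {
  std::lock_guard<std::mutex> lock(state.mutex);
  for (const auto& cache_weak : state.caches_to_update_on_destruction) {
    const auto cache = cache_weak.lock();
    if (!cache) continue;
    // the real destructor erases this provider's entry from the still-alive cache here
  }
}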
diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h index dd03db94b631c..064b526e604bc 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.h +++ b/onnxruntime/core/providers/cuda/cuda_graph.h @@ -6,7 +6,7 @@ #include #include "core/common/common.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cuda/cuda_pch.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 9d37a9775872f..054dd9f9da9f3 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -6,7 +6,7 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_fwd.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cuda/cuda_stream_handle.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cuda/cudnn_fe_call.cc b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc index 640025c248187..7cd320a26d973 100644 --- a/onnxruntime/core/providers/cuda/cudnn_fe_call.cc +++ b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc @@ -4,7 +4,7 @@ #include "core/providers/cuda/shared_inc/cudnn_fe_call.h" #include "core/providers/shared_library/provider_api.h" #include -#if !defined(__CUDACC__) +#if !defined(__CUDACC__) && !defined(USE_CUDA_MINIMAL) #include #endif #ifdef _WIN32 @@ -22,7 +22,7 @@ const char* CudaErrString(ERRTYPE) { ORT_NOT_IMPLEMENTED(); } -#if !defined(__CUDACC__) +#if !defined(__CUDACC__) && !defined(USE_CUDA_MINIMAL) #define CASE_ENUM_TO_STR_CUDNN_FE(x) \ case cudnn_frontend::error_code_t::x: \ return #x diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 71610634577ca..4dafbda409cd3 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -7,10 +7,6 @@ #include "cuda_common.h" namespace onnxruntime { -GPUDataTransfer::GPUDataTransfer() {} - -GPUDataTransfer::~GPUDataTransfer() {} - bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; @@ -30,19 +26,25 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const // Copy only if the two addresses are different. if (dst_data != src_data) { CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); } } else { // copy from other CPU memory to GPU, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. 
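The synchronization comments being added here (and continued just below with the API reference link) follow the CUDA runtime's documented memcpy behavior. A minimal standalone CUDA sketch of which directions still need the explicit default-stream sync, assuming the default stream and omitting error handling:

#include <cstddef>
#include <cuda_runtime.h>

void CopySemantics(float* dev_a, float* dev_b, float* pageable_host, float* pinned_host, size_t bytes) {
  // Device -> device: cudaMemcpy can return before the copy finishes,
  // so the code above keeps a cudaStreamSynchronize(nullptr) after it.
  cudaMemcpy(dev_b, dev_a, bytes, cudaMemcpyDeviceToDevice);
  cudaStreamSynchronize(nullptr);

  // Pageable host -> device: the DMA into the final destination may still be
  // in flight when cudaMemcpy returns, so the sync is kept for this case only.
  cudaMemcpy(dev_a, pageable_host, bytes, cudaMemcpyHostToDevice);
  cudaStreamSynchronize(nullptr);

  // Pinned host -> device and device -> host copies are synchronous with
  // respect to the host, so the extra sync is dropped for them.
  cudaMemcpy(dev_a, pinned_host, bytes, cudaMemcpyHostToDevice);
  cudaMemcpy(pageable_host, dev_a, bytes, cudaMemcpyDeviceToHost);
}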
+ // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); } else { // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } @@ -59,7 +61,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::CPU) { - // copy from pinned memory to GPU, this is non-blocking + // copy from pinned or non-pinned CPU memory to GPU CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking @@ -69,7 +71,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, } } else if (src_device.Type() == OrtDevice::GPU) { if (dst_device.Type() == OrtDevice::CPU) { - // copying from GPU to pinned memory, this is non-blocking + // copy from GPU to pinned or non-pinned CPU memory. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast(stream.GetHandle()))); } } else { @@ -77,6 +79,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, // sync the stream first to make sure the data arrived CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(stream.GetHandle()))); } + + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.h b/onnxruntime/core/providers/cuda/gpu_data_transfer.h index 68846e68079f3..11e21e91936fc 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.h +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.h @@ -10,8 +10,8 @@ namespace onnxruntime { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer(); - ~GPUDataTransfer(); + GPUDataTransfer() = default; + ~GPUDataTransfer() = default; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cuh b/onnxruntime/core/providers/cuda/math/topk_impl.cuh index cbde6da457fdb..112566e54bbba 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cuh @@ -412,7 +412,7 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, if (aligned_dimension <= GridDim::maxThreadsPerBlock) { BitonicTopK<<), stream>>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, - aligned_dimension, NumericLimits::Min(), NumericLimits::Max()); + aligned_dimension, NumericLimits::Lowest(), NumericLimits::Max()); } else if (K <= BT * 16 || 0 == sorted) { if (use_deterministic_compute) { static std::once_flag log_warning; @@ -425,19 +425,19 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, if (BT * 2 >= K || 0 == sorted) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else if (BT * 4 >= K) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, 
size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else if (BT * 8 >= K) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } } else { auto input_key_buffer = kernel->GetScratchBuffer(dimension, ort_stream); diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index cc76198dc3ae9..3129f519da2e5 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -457,7 +457,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected template Status Conv::ComputeInternal(OpKernelContext* context) const { - std::lock_guard lock(s_.mutex); + std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context)); if (s_.Y->Shape().Size() == 0) { return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 484d66081018b..e4047a6af272e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -13,7 +13,7 @@ #include #endif -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" #include "core/providers/cpu/nn/conv_attributes.h" @@ -190,7 +190,7 @@ struct CudnnConvState { TensorShapeVector slice_axes; // note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing - OrtMutex mutex; + std::mutex mutex; IAllocatorUniquePtr memory_for_cudnn_conv_results; ~CudnnConvState() { diff --git a/onnxruntime/core/providers/cuda/nn/conv_8.h b/onnxruntime/core/providers/cuda/nn/conv_8.h index 10239d09041fe..bcee1bcb7e231 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_8.h @@ -387,7 +387,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) template Status Conv::ComputeInternal(OpKernelContext* context) const { - std::lock_guard lock(s_.mutex); + std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context)); if (s_.Y->Shape().Size() == 0) { return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index d4876e1714861..2972ae999adc4 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -450,7 +450,7 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna template Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { - std::lock_guard lock(s_.mutex); + std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context, dynamic_padding)); if (s_.Y->Shape().Size() == 0) { return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h index b46d41b887e41..aa1fe26ac97db 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -87,7 +87,7 @@ Status 
ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy } { - std::lock_guard lock(s_.mutex); + std::lock_guard lock(s_.mutex); // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with // different batch_size diff --git a/onnxruntime/core/providers/cuda/nvtx_profile_context.h b/onnxruntime/core/providers/cuda/nvtx_profile_context.h index e2e3be07bd474..eb28f86becd20 100644 --- a/onnxruntime/core/providers/cuda/nvtx_profile_context.h +++ b/onnxruntime/core/providers/cuda/nvtx_profile_context.h @@ -7,7 +7,7 @@ #include #include -#include "core/platform/ort_mutex.h" +#include #ifdef ENABLE_NVTX_PROFILE @@ -25,14 +25,14 @@ class Context { // Return tag for the specified thread. // If the thread's tag doesn't exist, this function returns an empty string. std::string GetThreadTagOrDefault(const std::thread::id& thread_id) { - const std::lock_guard lock(mtx_); + const std::lock_guard lock(mtx_); return thread_tag_[thread_id]; } // Set tag for the specified thread. void SetThreadTag( const std::thread::id& thread_id, const std::string& tag) { - const std::lock_guard lock(mtx_); + const std::lock_guard lock(mtx_); thread_tag_[thread_id] = tag; } @@ -44,7 +44,7 @@ class Context { // map from thread's id to its human-readable tag. std::unordered_map thread_tag_; - OrtMutex mtx_; + std::mutex mtx_; }; } // namespace profile diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 860bea67dc719..4f8e6605ce151 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -16,17 +16,17 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ +#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, end, \ + begin, end, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ +#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -37,8 +37,13 @@ namespace cuda { name); #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ - REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ - REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) + +#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -829,14 +834,13 @@ template std::unique_ptr ReduceCompute class ArgMax final : public ReduceKernel { public: - ArgMax(const OpKernelInfo& info) : ReduceKernel(info) {} + ArgMax(const OpKernelInfo& info) : ReduceKernel(info) { + // The following is just a safety check. + // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMax + // nodes with select_last_index == 1 to the CUDA EP. 
+ int64_t select_last_index = 0; + if (info.GetAttr("select_last_index", &select_last_index).IsOK()) { + ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA"); + } + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_MAX); @@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel { template class ArgMin final : public ReduceKernel { public: - ArgMin(const OpKernelInfo& info) : ReduceKernel(info) {} + ArgMin(const OpKernelInfo& info) : ReduceKernel(info) { + // The following is just a safety check. + // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMin + // nodes with select_last_index == 1 to the CUDA EP. + int64_t select_last_index = 0; + if (info.GetAttr("select_last_index", &select_last_index).IsOK()) { + ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA"); + } + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_MIN); diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index ed642754af3ba..f9433642f0857 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "core/framework/float16.h" @@ -120,7 +121,7 @@ constexpr int kNumBitsPerBitmaskElement = std::numeric_limits struct NumericLimits { - __inline__ __host__ __device__ static T Min() { + __inline__ __host__ __device__ static T Lowest() { return std::numeric_limits::lowest(); } __inline__ __host__ __device__ static T Max() { @@ -128,43 +129,18 @@ struct NumericLimits { } }; -template <> -struct NumericLimits { - __inline__ __host__ __device__ static half Min() { - return -65504.0; - } - __inline__ __host__ __device__ static half Max() { - return 65504.0; - } -}; - template <> struct NumericLimits { - __inline__ __host__ __device__ static half Min() { - return -65504.0; - } - __inline__ __host__ __device__ static half Max() { - return 65504.0; + __inline__ __host__ __device__ static half Lowest() { + return -65504.0f; } -}; -template <> -struct NumericLimits { - __inline__ __host__ __device__ static float Min() { - return -INFINITY; - } - __inline__ __host__ __device__ static float Max() { - return INFINITY; - } -}; - -template <> -struct NumericLimits { - __inline__ __host__ __device__ static double Min() { - return -HUGE_VAL; - } - __inline__ __host__ __device__ static double Max() { - return HUGE_VAL; + __inline__ __host__ __device__ static half Max() { +#ifdef CUDART_MAX_NORMAL_FP16 // defined in cuda 12.3 or later + return CUDART_MAX_NORMAL_FP16; +#else + return 65504.0f; +#endif } }; diff --git a/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h index a51d84a7efa59..2ce7bc0bf51fd 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" -#if !defined(__CUDACC__) +#if !defined(__CUDACC__) && !defined(USE_CUDA_MINIMAL) #include #endif namespace onnxruntime { @@ -14,10 +14,12 @@ namespace onnxruntime { // Error handling // ----------------------------------------------------------------------- +#ifndef 
USE_CUDA_MINIMAL #define CUDNN_FE_CALL(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) #define CUDNN_FE_CALL_THROW(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu index 0dcc188d039a9..ce5a1ebf3faa5 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "nonzero_impl.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cu_inc/common.cuh" #include diff --git a/onnxruntime/core/providers/dml/CPPLINT.cfg b/onnxruntime/core/providers/dml/CPPLINT.cfg deleted file mode 100644 index 02d14c65cc861..0000000000000 --- a/onnxruntime/core/providers/dml/CPPLINT.cfg +++ /dev/null @@ -1 +0,0 @@ -filter=-whitespace/braces,-whitespace/parens,-whitespace/line_length,-whitespace/indent,-whitespace/newline diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index b1714a8220cd1..334a40b979bda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -41,23 +41,19 @@ namespace Dml D3D12_HEAP_FLAGS heapFlags, D3D12_RESOURCE_FLAGS resourceFlags, D3D12_RESOURCE_STATES initialState, - std::unique_ptr&& subAllocator - ) + std::unique_ptr&& subAllocator) : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) - ) - ), - m_device(device), - m_heapProperties(heapProps), - m_heapFlags(heapFlags), - m_resourceFlags(resourceFlags), - m_initialState(initialState), - m_context(context), - m_subAllocator(std::move(subAllocator)) - { + OrtMemoryInfo( + "DML", + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))), + m_device(device), + m_heapProperties(heapProps), + m_heapFlags(heapFlags), + m_resourceFlags(resourceFlags), + m_initialState(initialState), + m_context(context), + m_subAllocator(std::move(subAllocator)) { } /*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h index b22f0b2853e5d..f07b9540ff3fd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h @@ -24,7 +24,7 @@ namespace Dml OrtMemoryInfo( "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0) )) { m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp 
index 35a2c451a49a5..9f95818501dac 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -62,7 +62,8 @@ namespace Dml const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; const auto kernel_lookup = onnxruntime::KernelLookup{provider_type, gsl::make_span(®istry, 1), - kernel_type_str_resolver}; + kernel_type_str_resolver, + logger}; std::vector> compiledPartitionInfos; std::vector additionalSplittingNodes; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index 6318b0d5e2865..b9b90d6bc17bd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -54,7 +54,8 @@ namespace Dml const auto kernelLookup = onnxruntime::KernelLookup( providerType, gsl::make_span(®istry, 1), - kernelTypeStrResolver); + kernelTypeStrResolver, + logger); onnxruntime::GraphViewer graphViewer(graph); const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 9c01df13741e1..826f48b5f7a68 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -74,7 +74,7 @@ namespace Dml bool enableGraphCapture, bool enableSyncSpinning, bool disableMemoryArena) : - IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) + IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) @@ -95,7 +95,7 @@ namespace Dml const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_lookup); + return m_impl->GetCapability(graph, kernel_lookup, *GetLogger()); #else return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup); #endif @@ -876,7 +876,8 @@ namespace Dml std::vector> ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger) const { uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
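The mask mentioned in the comment above packs one bit per DML_TENSOR_DATA_TYPE, so supported-type checks reduce to plain bit tests. An illustrative, hypothetical sketch (not the DML EP's actual helpers):

#include <cstdint>

constexpr uint32_t BitFor(uint32_t dataTypeId) { return 1u << dataTypeId; }

constexpr bool IsSupported(uint32_t deviceDataTypeMask, uint32_t dataTypeId) {
  return (deviceDataTypeMask & BitFor(dataTypeId)) != 0;
}

// e.g. a device mask that supports type ids 0, 1 and 3:
constexpr uint32_t kMask = BitFor(0) | BitFor(1) | BitFor(3);
static_assert(IsSupported(kMask, 1) && !IsSupported(kMask, 2), "one bit per data type");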
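Stepping back to the NumericLimits rework in cuda_utils.h a little further up: the diff rendering stripped the template arguments there, so here is a hedged reconstruction of what the updated trait plausibly looks like. The Min() -> Lowest() rename (also applied at the TopK call sites) matters because std::numeric_limits<T>::min() is the smallest positive normal value for floating-point types, not the most negative one; the half specialization presumably remains because std::numeric_limits is not specialized for CUDA's half type. Assumes nvcc with cuda_fp16.h on the include path.

#include <limits>

template <typename T>
struct NumericLimits {
  __inline__ __host__ __device__ static T Lowest() { return std::numeric_limits<T>::lowest(); }
  __inline__ __host__ __device__ static T Max() { return std::numeric_limits<T>::max(); }
};

template <>
struct NumericLimits<half> {
  __inline__ __host__ __device__ static half Lowest() { return -65504.0f; }
  __inline__ __host__ __device__ static half Max() {
#ifdef CUDART_MAX_NORMAL_FP16  // defined in CUDA 12.3 or later
    return CUDART_MAX_NORMAL_FP16;
#else
    return 65504.0f;
#endif
  }
};

static_assert(std::numeric_limits<float>::min() > 0.0f, "min() is not the most negative float");
static_assert(std::numeric_limits<float>::lowest() == -std::numeric_limits<float>::max(), "lowest() is");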
@@ -900,7 +901,7 @@ namespace Dml } // Get the list of nodes that should stay on the CPU - auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes); + auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger); for (size_t nodeIndex : toplogicalOrder) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index c20969250fe84..e7d859c5764de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -88,7 +88,8 @@ namespace Dml std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger ) const; uint32_t GetSupportedDeviceDataTypeMask() const; @@ -242,8 +243,8 @@ namespace Dml bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final { - return (srcDevice.Type() == OrtDevice::GPU) || - (dstDevice.Type() == OrtDevice::GPU); + return (srcDevice.Type() == OrtDevice::DML) || + (dstDevice.Type() == OrtDevice::DML); } private: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg index bf14c49304415..7e6be3c6874d5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg @@ -1 +1 @@ -filter=-whitespace/comments,-readability/todo,-whitespace/end_of_line,-runtime/indentation_namespace +filter=-readability/todo diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 45ff25c4fdd90..389bdee3a365b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -28,9 +28,56 @@ class DmlOperatorCast : public DmlOperator castDesc.InputTensor = inputDescs.data(); castDesc.OutputTensor = outputDescs.data(); - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + if (kernelInfo.GetOutputEdgeDescription(0).tensorDataType == static_cast(ONNX_NAMESPACE::TensorProto_DataType_BOOL)) + { + DML_OPERATOR_DESC dmlCastDesc = { DML_OPERATOR_CAST, &castDesc }; - SetDmlOperatorDesc(opDesc, kernelInfo); + DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clipDesc = {}; + clipDesc.InputTensor = outputDescs.data(); + clipDesc.OutputTensor = outputDescs.data(); + clipDesc.Min.UInt8 = 0; + clipDesc.Max.UInt8 = 1; + + DML_OPERATOR_DESC dmlClipDesc = { DML_OPERATOR_ELEMENT_WISE_CLIP1, &clipDesc }; + + std::vector opDescs = { &dmlCastDesc, &dmlClipDesc }; + + DML_INPUT_GRAPH_EDGE_DESC inputToCastEdge = {}; + inputToCastEdge.GraphInputIndex = 0; + inputToCastEdge.ToNodeIndex = 0; + inputToCastEdge.ToNodeInputIndex = 0; + + DML_INTERMEDIATE_GRAPH_EDGE_DESC castToClipEdge = {}; + castToClipEdge.FromNodeIndex = 0; + castToClipEdge.FromNodeOutputIndex = 0; + castToClipEdge.ToNodeIndex = 1; + castToClipEdge.ToNodeInputIndex = 0; + + DML_OUTPUT_GRAPH_EDGE_DESC clipToOutputEdge = {}; + clipToOutputEdge.FromNodeIndex = 1; + clipToOutputEdge.FromNodeOutputIndex 
= 0; + clipToOutputEdge.GraphOutputIndex = 0; + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + + operatorGraphDesc.inputEdgeCount = 1; + operatorGraphDesc.inputEdges = &inputToCastEdge; + + operatorGraphDesc.intermediateEdgeCount = 1; + operatorGraphDesc.intermediateEdges = &castToClipEdge; + + operatorGraphDesc.outputEdgeCount = 1; + operatorGraphDesc.outputEdges = &clipToOutputEdge; + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } + else + { + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } } void Compute(const MLOperatorKernelContext& kernelContext) @@ -50,5 +97,6 @@ class DmlOperatorCast : public DmlOperator DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast); DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast); DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast); +DML_OP_DEFINE_CREATION_FUNCTION(CastLike21, DmlOperatorCast); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index 9b7ad9aa9e088..f8710fd266c07 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -123,5 +123,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Pad21, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp index 88b827f61f0c9..ad7c77510d988 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp @@ -127,51 +127,51 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper DML_OPERATOR_DESC& desc = descs[i]; ActivationOperatorDescUnion& activationDesc = m_activationDescs[i]; desc.Desc = &activationDesc; - - if (activationName == AttrValue::ActivationRelu) + + if (CompareActivationName(activationName, AttrValue::ActivationRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_RELU; - } - else if (activationName == AttrValue::ActivationLeakyRelu) + } + else if (CompareActivationName(activationName, AttrValue::ActivationLeakyRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU; activationDesc.leakyRelu.Alpha = NextAlpha(desc.Type); } - else if (activationName == AttrValue::ActivationThresholdedRelu) + else if (CompareActivationName(activationName, AttrValue::ActivationThresholdedRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU; activationDesc.thresholdedRelu.Alpha = NextAlpha(desc.Type); - } - else if (activationName == AttrValue::ActivationTanh) + } + else if (CompareActivationName(activationName, AttrValue::ActivationTanh)) { desc.Type = DML_OPERATOR_ACTIVATION_TANH; - } - else if (activationName == AttrValue::ActivationScaledTanh) + } + else if (CompareActivationName(activationName, 
AttrValue::ActivationScaledTanh)) { desc.Type = DML_OPERATOR_ACTIVATION_SCALED_TANH; activationDesc.scaledTanh.Alpha = NextAlpha(desc.Type); activationDesc.scaledTanh.Beta = NextBeta(desc.Type); - } - else if (activationName == AttrValue::ActivationSigmoid) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSigmoid)) { desc.Type = DML_OPERATOR_ACTIVATION_SIGMOID; - } - else if (activationName == AttrValue::ActivationSigmoidHard) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSigmoidHard)) { desc.Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID; activationDesc.hardSigmoid.Alpha = NextAlpha(desc.Type); activationDesc.hardSigmoid.Beta = NextBeta(desc.Type); - } - else if (activationName == AttrValue::ActivationElu) + } + else if (CompareActivationName(activationName, AttrValue::ActivationElu)) { desc.Type = DML_OPERATOR_ACTIVATION_ELU; activationDesc.elu.Alpha = NextAlpha(desc.Type); - } - else if (activationName == AttrValue::ActivationSoftsign) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSoftsign)) { desc.Type = DML_OPERATOR_ACTIVATION_SOFTSIGN; - } - else if (activationName == AttrValue::ActivationSoftplus) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSoftplus)) { desc.Type = DML_OPERATOR_ACTIVATION_SOFTPLUS; } @@ -182,6 +182,12 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper } } + bool CompareActivationName(std::string_view activationName, std::string_view attrValue) + { + auto comparer = [](char a, char b) {return std::tolower(a) == std::tolower(b);}; + return std::equal(activationName.begin(), activationName.end(), attrValue.begin(), attrValue.end(), comparer); + } + void Compute(const MLOperatorKernelContext& kernelContext) override { // Assume that enough GPU work has been queued up after the RNN operator that it is worth diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 0f15ebf342b3a..95d9644b4ca30 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -25,6 +25,41 @@ // The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap // the sign of every adjacent element. 
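The CompareActivationName helper introduced above makes the RNN activation lookup case-insensitive. A standalone equivalent is sketched below; the one deliberate difference is the cast to unsigned char before std::tolower, which avoids undefined behavior for negative char values and may be worth folding into the helper as well:

#include <algorithm>
#include <cctype>
#include <string_view>

static bool EqualsIgnoreCase(std::string_view a, std::string_view b) {
  auto comparer = [](char x, char y) {
    return std::tolower(static_cast<unsigned char>(x)) == std::tolower(static_cast<unsigned char>(y));
  };
  // The four-iterator overload also rejects strings of different lengths.
  return std::equal(a.begin(), a.end(), b.begin(), b.end(), comparer);
}

// EqualsIgnoreCase("tanh", "Tanh") == true; EqualsIgnoreCase("tan", "Tanh") == false.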
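The half-swap described in the comment above is the standard rotary embedding formula applied per head. As a plain C++ reference for the math that the DML graph in this file assembles (names are local to this sketch; non-interleaved layout, one head at one position):

#include <cstddef>
#include <vector>

void RotaryEmbedHalfSplit(const std::vector<float>& x,      // one head, size = rotaryEmbeddingDim
                          const std::vector<float>& cos_v,  // gathered at this position, size = dim / 2
                          const std::vector<float>& sin_v,
                          std::vector<float>& out) {
  const size_t half = x.size() / 2;
  out.resize(x.size());
  for (size_t i = 0; i < half; ++i) {
    // rotate_half: the second half is negated and moved in front, then cos/sin are applied
    out[i]        = x[i] * cos_v[i] - x[i + half] * sin_v[i];
    out[i + half] = x[i + half] * cos_v[i] + x[i] * sin_v[i];
  }
}

// Worked shape arithmetic for the operator below (hypothetical numbers): with a 3D input
// of hidden size 4096 and a cos_cache whose last dimension is 64, headSize = 64 * 2 = 128,
// numHeads = 4096 / 128 = 32, and rotaryEmbeddingDim defaults to the full headSize. When a
// smaller rotary_embedding_dim is given, only that leading slice of each head is rotated and
// the remainder passes through via the split/join nodes added in this change.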
+// Here's a representation of what the graph looks like in DML, before getting fused together: +/* + Input CosCache PositionIds SinCache + | | | | + | | +--------+-----------+ | + Split | | | | + | | Gather Gather + +-------+ | | | + | | | | + | Identity----------+ | | + | | | | | + | | | | | + | --Split-- | | | + | \ / | +-----------------+ | + | \ / | | | + | \ / Mul | + | \ / | | + | X | | + | / \ | | + | / \ | | + | Join | | + | | | | + | | +---------------------------------------------------------+ + | | | | + | Mul | + | | | + | +-----+ +------+ + | | | + | Add + | | + +-------------+ | + | | + Join +*/ + namespace Dml { class DmlOperatorRotaryEmbedding : public DmlOperator @@ -56,25 +91,45 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4); - ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); - const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; - // The last dimension of the data is the hidden size, so it must be divisible by the head size - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0); + uint32_t numHeads = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::NumHeads, 0)); + uint32_t rotaryEmbeddingDim = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::RotaryEmbeddingDim, 0)); - // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); + const uint32_t hiddenSize = inputIs4D ? inputDataSizes[1] * inputDataSizes[3] : inputDataSizes.back(); + + const uint32_t headSize = numHeads == 0 + ? m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2 + : hiddenSize / numHeads; + + if (rotaryEmbeddingDim > 0) + { + ORT_ENFORCE(numHeads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } + else + { + rotaryEmbeddingDim = headSize; + } + + if (numHeads == 0) + { + numHeads = hiddenSize / headSize; + } + else if (inputIs4D) + { + ORT_ENFORCE(numHeads == inputDataSizes[1], "When the input has 4 dimensions, num_heads must be 0 or have the same value as the second dimension of the input"); + } + const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1]; const uint32_t sequenceLength = inputDataSizes[2]; - const uint32_t numHeads = inputIs4D ? 
inputDataSizes[1] : inputDataSizes[3] / headSize; const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; - if (sequenceLength > maxSequenceLength) + const bool isPackedBatching = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::IsPackedBatching, 0)) == 1; + if (!isPackedBatching && sequenceLength > maxSequenceLength) { ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); } @@ -84,64 +139,103 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; + // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] + const std::array inputOutputShape = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, headSize}) + : std::array({batchSize, sequenceLength, numHeads, headSize}); + + const std::array splitInputOutputShape1 = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, rotaryEmbeddingDim}) + : std::array({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim}); + + const std::array splitInputOutputShape2 = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, headSize - rotaryEmbeddingDim}) + : std::array({batchSize, sequenceLength, numHeads, headSize - rotaryEmbeddingDim}); + TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); - TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + TensorDesc splitInputOutputTensorDesc1 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape1); + TensorDesc splitInputOutputTensorDesc2 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape2); - if (inputIs4D) + // Split the input to perform the rotary embedding only on a subregion of the tensor if needed. The split inputs + // will be joined back together at the end. + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + + std::array splitTensorDescs = { + splitInputOutputTensorDesc1.GetDmlDesc(), + splitInputOutputTensorDesc2.GetDmlDesc(), + }; + + DML_SPLIT_OPERATOR_DESC splitInputOperatorDesc{}; + DML_OPERATOR_DESC splitInputDmlOperatorDesc{}; + if (headSize != rotaryEmbeddingDim) { - const std::array inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1}; - stridedInputOutputTensorDesc.SetStrides(inputOutputStrides); + splitInputOperatorDesc.InputTensor = &inputOutputDmlTensorDesc; + splitInputOperatorDesc.OutputCount = gsl::narrow_cast(splitTensorDescs.size()); + splitInputOperatorDesc.OutputTensors = splitTensorDescs.data(); + splitInputOperatorDesc.Axis = gsl::narrow_cast(inputOutputShape.size()) - 1; + splitInputDmlOperatorDesc.Type = DML_OPERATOR_SPLIT; + splitInputDmlOperatorDesc.Desc = &splitInputOperatorDesc; } - const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); - const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc(); - - // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. 
+ // Copy the partial input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; + const std::array partialInputOutputShape = {batchSize, sequenceLength, numHeads, rotaryEmbeddingDim}; + TensorDesc partialStridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape); + TensorDesc partialInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape); + + if (inputIs4D) + { + const std::array partialInputOutputStrides = {rotaryEmbeddingDim * numHeads * sequenceLength, rotaryEmbeddingDim, sequenceLength * rotaryEmbeddingDim, 1}; + partialStridedInputOutputTensorDesc.SetStrides(partialInputOutputStrides); + } + + const DML_TENSOR_DESC partialStridedInputOutputDmlTensorDesc = partialStridedInputOutputTensorDesc.GetDmlDesc(); + const DML_TENSOR_DESC partialInputOutputDmlTensorDesc = partialInputOutputTensorDesc.GetDmlDesc(); + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; - copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc; - copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.InputTensor = &partialStridedInputOutputDmlTensorDesc; + copyInputDesc.OutputTensor = &partialInputOutputDmlTensorDesc; copyInputDesc.ScaleBias = &scaleBias; const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; + const uint32_t halfRoraryEmbeddingDim = rotaryEmbeddingDim / 2; + // Split the input data into 2 equal parts - const std::vector inputDataTensorShape = interleaved - ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) - : std::vector({batchSize, sequenceLength, numHeads, 2, headSize / 2}); + const std::vector partialInputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 2}) + : std::vector({batchSize, sequenceLength, numHeads, 2, rotaryEmbeddingDim / 2}); const std::vector splitInputDataTensorShape = interleaved - ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 1}) - : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); + ? 
std::vector({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 1}) + : std::vector({batchSize, sequenceLength, numHeads, 1, rotaryEmbeddingDim / 2}); - TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + TensorDesc partialInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape); + const DML_TENSOR_DESC partialInputDataDmlTensorDesc = partialInputDataTensorDesc.GetDmlDesc(); - const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); - - TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape); const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc(); TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; - DML_SPLIT_OPERATOR_DESC splitInputDesc{}; - splitInputDesc.InputTensor = &inputDataDmlTensorDesc; - splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); - splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - splitInputDesc.Axis = interleaved + DML_SPLIT_OPERATOR_DESC splitPartialInputDesc{}; + splitPartialInputDesc.InputTensor = &partialInputDataDmlTensorDesc; + splitPartialInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + splitPartialInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitPartialInputDesc.Axis = interleaved ? gsl::narrow_cast(splitInputDataTensorShape.size()) - 1 : gsl::narrow_cast(splitInputDataTensorShape.size()) - 2; - const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + const DML_OPERATOR_DESC splitPartialInputDmlDesc = {DML_OPERATOR_SPLIT, &splitPartialInputDesc}; // Swap the 2 halves and join them together - DML_JOIN_OPERATOR_DESC joinInputDesc{}; - joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc; - joinInputDesc.Axis = splitInputDesc.Axis; - joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; + DML_JOIN_OPERATOR_DESC joinPartialInputDesc{}; + joinPartialInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinPartialInputDesc.OutputTensor = &joinedDataDmlTensorDesc; + joinPartialInputDesc.Axis = splitPartialInputDesc.Axis; + joinPartialInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinPartialInputDmlDesc = {DML_OPERATOR_JOIN, &joinPartialInputDesc}; // We generate a sequence from 0 to sequenceLength and add the offset to it const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; @@ -177,7 +271,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; // Gather the cos/sin values based on the position ids - const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, rotaryEmbeddingDim / 2}; TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, 
gatheredCosSinShape); const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); @@ -191,9 +285,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data const std::vector reshapedCosSinShape = interleaved - ? std::vector({batchSize, sequenceLength, 1, headSize / 2, 1}) - : std::vector({batchSize, sequenceLength, 1, 1, headSize / 2}); - TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape); + ? std::vector({batchSize, sequenceLength, 1, rotaryEmbeddingDim / 2, 1}) + : std::vector({batchSize, sequenceLength, 1, 1, rotaryEmbeddingDim / 2}); + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedCosSinShape); const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); // Create a vector that contains the sign values {-1, 1} @@ -224,7 +318,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const std::vector reshapedSignShape = interleaved ? std::vector({1, 1, 1, 1, 2}) : std::vector({1, 1, 1, 2, 1}); - TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape); + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedSignShape); const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; @@ -242,11 +336,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; - addDesc.ATensor = &inputOutputDmlTensorDesc; - addDesc.BTensor = &inputOutputDmlTensorDesc; - addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc; + addDesc.ATensor = &partialInputOutputDmlTensorDesc; + addDesc.BTensor = &partialInputOutputDmlTensorDesc; + addDesc.OutputTensor = &partialStridedInputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; + DML_JOIN_OPERATOR_DESC joinOutputOperatorDesc{}; + DML_OPERATOR_DESC joinOutputDmlOperatorDesc{}; + if (headSize != rotaryEmbeddingDim) + { + joinOutputOperatorDesc.InputCount = gsl::narrow_cast(splitTensorDescs.size()); + joinOutputOperatorDesc.InputTensors = splitTensorDescs.data(); + joinOutputOperatorDesc.OutputTensor = &inputOutputDmlTensorDesc; + joinOutputOperatorDesc.Axis = gsl::narrow_cast(inputOutputShape.size()) - 1; + joinOutputDmlOperatorDesc.Type = DML_OPERATOR_JOIN; + joinOutputDmlOperatorDesc.Desc = &joinOutputOperatorDesc; + } + // Construct the graph std::vector inputEdges; std::vector intermediateEdges; @@ -254,12 +360,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector opDescs = { ©InputDmlDesc, // Copy the input data to preseve the real input shape - &splitInputDmlDesc, // Split the input data + &splitPartialInputDmlDesc, // Split the input data &gatherCosSinDmlDesc, // Gather cos &gatherCosSinDmlDesc, // Gather sin &signRangeDmlDesc, // Generate the signs - &joinInputDmlDesc, // Join the split data + &joinPartialInputDmlDesc, // Join the split data &mulCosSinDmlDesc, // Multiply cos with the non-rotated data &mulCosSinDmlDesc, // Multiply sin with the rotated data &mulSignDmlDesc, // Multiply the sign with the 
rotated data @@ -269,12 +375,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator enum NodeIndex : uint32_t { copyInputOpIndex, - splitInputOpIndex, + splitPartialInputOpIndex, gatherCosOpIndex, gatherSinOpIndex, signRangeOpIndex, - joinInputOpIndex, + joinPartialInputOpIndex, mulCosOpIndex, mulSinOpIndex, mulSignOpIndex, @@ -285,6 +391,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator positionIdsAddOffsetOpIndex, }; + uint32_t splitInputOpIndex = positionIdsIsOffset ? positionIdsAddOffsetOpIndex + 1 : addOpIndex + 1; + uint32_t joinOutputOpIndex = splitInputOpIndex + 1; + if (positionIdsIsOffset) { opDescs.push_back(&positionIdsRangeDmlDesc); @@ -332,11 +441,32 @@ class DmlOperatorRotaryEmbedding : public DmlOperator inputEdges.push_back(positionIdsToGatherSinEdge); } - DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; - inputToCopyInputEdge.GraphInputIndex = inputDataIndex; - inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; - inputToCopyInputEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputToCopyInputEdge); + if (splitInputDmlOperatorDesc.Desc) + { + opDescs.push_back(&splitInputDmlOperatorDesc); + opDescs.push_back(&joinOutputDmlOperatorDesc); + + DML_INPUT_GRAPH_EDGE_DESC inputToSplitInputEdge = {}; + inputToSplitInputEdge.GraphInputIndex = inputDataIndex; + inputToSplitInputEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToSplitInputEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC partialInputToCopyInputEdge = {}; + partialInputToCopyInputEdge.FromNodeIndex = splitInputOpIndex; + partialInputToCopyInputEdge.FromNodeOutputIndex = 0; + partialInputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + partialInputToCopyInputEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(partialInputToCopyInputEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; + inputToCopyInputEdge.GraphInputIndex = inputDataIndex; + inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + inputToCopyInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToCopyInputEdge); + } DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; cosToGatherEdge.GraphInputIndex = cosCacheIndex; @@ -353,7 +483,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {}; inputToSplitEdge.FromNodeIndex = copyInputOpIndex; inputToSplitEdge.FromNodeOutputIndex = 0; - inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeIndex = splitPartialInputOpIndex; inputToSplitEdge.ToNodeInputIndex = 0; intermediateEdges.push_back(inputToSplitEdge); @@ -365,16 +495,16 @@ class DmlOperatorRotaryEmbedding : public DmlOperator intermediateEdges.push_back(nonRotatedDataToMulEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; - secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex; secondHalfDataToJoinEdge.FromNodeOutputIndex = 1; - secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + secondHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex; secondHalfDataToJoinEdge.ToNodeInputIndex = 0; intermediateEdges.push_back(secondHalfDataToJoinEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {}; - firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex; firstHalfDataToJoinEdge.FromNodeOutputIndex = 0; - firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + 
firstHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex; firstHalfDataToJoinEdge.ToNodeInputIndex = 1; intermediateEdges.push_back(firstHalfDataToJoinEdge); @@ -386,7 +516,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator intermediateEdges.push_back(cosToMulEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {}; - rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex; + rotatedDataToMulEdge.FromNodeIndex = joinPartialInputOpIndex; rotatedDataToMulEdge.FromNodeOutputIndex = 0; rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex; rotatedDataToMulEdge.ToNodeInputIndex = 0; @@ -427,11 +557,36 @@ class DmlOperatorRotaryEmbedding : public DmlOperator rotatedSinToAddEdge.ToNodeInputIndex = 1; intermediateEdges.push_back(rotatedSinToAddEdge); - DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; - addToOutputEdge.FromNodeIndex = addOpIndex; - addToOutputEdge.FromNodeOutputIndex = 0; - addToOutputEdge.GraphOutputIndex = 0; - outputEdges.push_back(addToOutputEdge); + if (splitInputDmlOperatorDesc.Desc) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC addToJoinOutputEdge = {}; + addToJoinOutputEdge.FromNodeIndex = addOpIndex; + addToJoinOutputEdge.FromNodeOutputIndex = 0; + addToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex; + addToJoinOutputEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(addToJoinOutputEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC remainingInputToJoinOutputEdge = {}; + remainingInputToJoinOutputEdge.FromNodeIndex = splitInputOpIndex; + remainingInputToJoinOutputEdge.FromNodeOutputIndex = 1; + remainingInputToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex; + remainingInputToJoinOutputEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(remainingInputToJoinOutputEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC joinOutputToOutputEdge = {}; + joinOutputToOutputEdge.FromNodeIndex = joinOutputOpIndex; + joinOutputToOutputEdge.FromNodeOutputIndex = 0; + joinOutputToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(joinOutputToOutputEdge); + } + else + { + DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; + addToOutputEdge.FromNodeIndex = addOpIndex; + addToOutputEdge.FromNodeOutputIndex = 0; + addToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(addToOutputEdge); + } MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 2375131cb34ea..b0b37d01370bc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -365,6 +365,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad11); DML_OP_EXTERN_CREATION_FUNCTION(Pad13); DML_OP_EXTERN_CREATION_FUNCTION(Pad18); DML_OP_EXTERN_CREATION_FUNCTION(Pad19); +DML_OP_EXTERN_CREATION_FUNCTION(Pad21); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); DML_OP_EXTERN_CREATION_FUNCTION(Sqrt); @@ -445,6 +446,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeMatMul); DML_OP_EXTERN_CREATION_FUNCTION(Cast); DML_OP_EXTERN_CREATION_FUNCTION(CastLike15); DML_OP_EXTERN_CREATION_FUNCTION(CastLike19); +DML_OP_EXTERN_CREATION_FUNCTION(CastLike21); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost); DML_OP_EXTERN_CREATION_FUNCTION(TopK7); @@ -792,6 +794,7 @@ constexpr 
static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 18, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO( 21, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis. {REG_INFO( 13, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis. @@ -804,6 +807,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, + {REG_INFO_VER( 21, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, #if DML_TARGET_VERSION >= 0x6400 {REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, @@ -819,6 +823,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 13, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported, requiredConstantCpuInputs(0))}, + {REG_INFO( 21, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported, requiredConstantCpuInputs(0))}, {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)}, {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)}, {REG_INFO( 13, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)}, @@ -849,10 +854,12 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(21, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Flatten, typeNameListDefault, 
supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(13, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(21, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(13, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, @@ -1087,6 +1094,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 21, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 21, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)}, @@ -1102,6 +1110,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO( 21, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, @@ -1149,6 +1158,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, + {REG_INFO( 21, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)}, {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, @@ -1162,6 +1172,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, 
supportedTypeListGroupNorm, DmlGraphSupport::Supported)}, + {REG_INFO( 21, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, MatMulNBits, typeNameListTwo, supportedTypeListMatMulNBits, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMatMulNBits)}, // Operators that need to alias an input with an output diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp index 375ee87bd42f1..a4d284df43a72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp @@ -125,8 +125,8 @@ namespace Dml // No chunks were able to accommodate the allocation - create a new chunk and return that instead - // At least double the capacity of the pool - const size_t newChunkSize = std::max({ m_totalCapacity, c_minChunkSize, sizeInBytes }); + // At least double the capacity of the pool, limit to c_maxChunkSize so DX12 does not reject size + const size_t newChunkSize = std::min(std::max({ m_totalCapacity, c_minChunkSize, sizeInBytes }), c_maxChunkSize); m_chunks.push_back(CreateChunk(m_device.Get(), newChunkSize)); m_totalCapacity += newChunkSize; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h index 1202ae9243921..0315b087519ba 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h @@ -32,6 +32,7 @@ namespace Dml private: static constexpr size_t c_minChunkSize = 1024 * 1024; // 1MB static constexpr size_t c_allocationAlignment = 512; // In bytes; as per D3D12 requirement for buffers + static constexpr size_t c_maxChunkSize = 0xFFFF0000; // ~4 GiB limitation for DX12 CPU-visible resource // A suballoction from a chunk struct Allocation diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 0c5739554b800..3d23fb6206479 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -130,6 +130,8 @@ namespace AttrName static constexpr const char* UppercaseN = "N"; static constexpr const char* UppercaseK = "K"; static constexpr const char* MatMulNBitsBlockSize = "block_size"; + static constexpr const char* RotaryEmbeddingDim = "rotary_embedding_dim"; + static constexpr const char* IsPackedBatching = "is_packed_batching"; } // namespace AttrName diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 323fcc779d98d..c52e26dd321ab 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1648,6 +1648,7 @@ using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper; using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_GroupNorm = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_GroupNorm21 = GetOutputShapeAsInputShapeHelper; using 
ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_SkipLayerNormalization = SkipLayerNormHelper; @@ -1673,6 +1674,7 @@ using ShapeInferenceHelper_Flatten7 = FlattenHelper; using ShapeInferenceHelper_Flatten9 = FlattenHelper; using ShapeInferenceHelper_Flatten11 = FlattenHelper; using ShapeInferenceHelper_Flatten13 = FlattenHelper; +using ShapeInferenceHelper_Flatten21 = FlattenHelper; using ShapeInferenceHelper_Split7 = VersionedOpsetHelper; using ShapeInferenceHelper_Split11 = VersionedOpsetHelper; using ShapeInferenceHelper_Split13 = VersionedOpsetHelper; @@ -1689,6 +1691,7 @@ using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper; +using ShapeInferenceHelper_Pad21 = VersionedOpsetHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; @@ -1837,6 +1840,7 @@ using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Identity21 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper; @@ -1865,6 +1869,7 @@ using ShapeInferenceHelper_Range = RangeHelper; using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_CastLike21 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DmlFusedConv = ConvHelper; using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 26529c0d59dd6..b4d402a1d9e77 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -446,6 +446,15 @@ namespace OperatorHelper static const int sc_sinceVer_Reshape = 21; static const int sc_sinceVer_Cast = 21; static const int sc_sinceVer_Shape = 21; + static const int sc_sinceVer_Size = 21; + static const int sc_sinceVer_CastLike = 21; + static const int sc_sinceVer_ConstantOfShape = 21; + static const int sc_sinceVer_Flatten = 21; + static const int sc_sinceVer_Pad = 21; + static const int sc_sinceVer_Transpose = 21; + static const int sc_sinceVer_Identity = 21; + static const int sc_sinceVer_QLinearMatMul = 21; + static const int sc_sinceVer_GroupNorm = 21; } namespace MsftOperatorSet1 diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index ffda84921a3ee..c96f9cc1ff400 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -12,7 +12,7 @@ #include #endif // defined(DNNL_OPENMP) -#include "core/platform/ort_mutex.h" +#include #include "core/providers/shared_library/provider_api.h" 
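The oneDNN provider hunks in this region replace the custom OrtMutex wrapper from core/platform/ort_mutex.h with the standard std::mutex, and the subgraph compile path is serialized with std::unique_lock as shown further below. A minimal, self-contained sketch of that locking pattern (the class here is a hypothetical stand-in, not the real DnnlSubgraphPrimitive):

    // Stripped-down illustration of guarding a shared compile step with std::mutex,
    // mirroring std::unique_lock<std::mutex> lock(subgraph_primitive->GetMutex()).
    #include <mutex>
    #include <vector>

    class SubgraphPrimitiveSketch {  // hypothetical stand-in for DnnlSubgraphPrimitive
     public:
      std::mutex& GetMutex() { return mutex_; }

      void Compile(const std::vector<float>& inputs) {
        // ... build shared primitive state from the inputs here ...
        (void)inputs;
        compiled_ = true;
      }

     private:
      std::mutex mutex_;
      bool compiled_ = false;
    };

    void CompileFromAnyThread(SubgraphPrimitiveSketch& primitive, const std::vector<float>& inputs) {
      // Multiple threads share the primitive's memory, so compilation is serialized.
      std::unique_lock<std::mutex> lock(primitive.GetMutex());
      primitive.Compile(inputs);
    }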
#include "core/providers/dnnl/dnnl_execution_provider.h" @@ -356,7 +356,7 @@ Status DnnlExecutionProvider::Compile(const std::vector& fuse // lock each subgraph_primitive as multiple threads have shared memories { - std::unique_lock lock(subgraph_primitive->GetMutex()); + std::unique_lock lock(subgraph_primitive->GetMutex()); subgraph_primitive->Compile(inputs); std::unordered_map outputs; outputs.reserve(subgraph_num_outputs); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h index a7e49b54d4507..3bd12f1cf6f7e 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h @@ -4,7 +4,7 @@ #pragma once #include "dnnl_subgraph.h" #include "dnnl.hpp" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { namespace ort_dnnl { @@ -69,7 +69,7 @@ class DnnlSubgraphPrimitive { // If the input being a scalar affects the operator this function can be used to determine if the // original input from ORT was a scalar. bool IsScalar(const DnnlTensor& tensor); - OrtMutex& GetMutex() { return mutex_; } + std::mutex& GetMutex() { return mutex_; } // GetMemory in OrtFormat if the memory is not in the OrtFormat this will reorder the memory. // All memory will be moved to the dnnl_engine even if it is already in OrtFormat. @@ -125,7 +125,7 @@ class DnnlSubgraphPrimitive { dnnl::engine cpu_engine_; dnnl::engine gpu_engine_; - OrtMutex mutex_; + std::mutex mutex_; // for memory debug purpose std::vector> items_to_print_; diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index d2a72c3a38b03..7d8c5525667b9 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -66,14 +66,6 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, -#endif - }, - { - kTvmExecutionProvider, -#ifdef USE_TVM - true, -#else - false, #endif }, { diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index ebea041b80128..3809df2c82e4c 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -6,7 +6,7 @@ #include "core/providers/js/data_transfer.h" EM_ASYNC_JS(void, jsepDownload, (const void* src_data, void* dst_data, size_t bytes), { - await Module.jsepCopyAsync(src_data, dst_data, bytes); + await Module.jsepCopyAsync(Number(src_data), Number(dst_data), Number(bytes)); }); namespace onnxruntime { @@ -30,10 +30,10 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::GPU) { // copy from GPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes); + EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2), true); }, src_data, dst_data, bytes); } else { // copy from CPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes); + EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes); } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 1ff33f6d7b410..c1a8b373bed84 100644 --- 
a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -121,7 +121,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, Not) class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 8, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 10, Clip); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, Clip); @@ -139,7 +140,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); @@ -150,7 +152,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); @@ -233,17 +236,20 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Res class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, 
Unsqueeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, DepthToSpace); @@ -273,10 +279,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 18, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, AveragePool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 18, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, AveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); @@ -333,6 +341,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, GatherND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, GatherND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice); @@ -341,7 +353,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sli class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten); class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile); @@ -358,12 +371,14 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); @@ -389,6 +404,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 15, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 17, ScatterND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ScatterND); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -439,7 +461,8 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - KERNEL_CREATE_INFO(19, Cast), + KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast), + KERNEL_CREATE_INFO(21, Cast), // activations KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip), @@ -501,12 +524,14 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, 
BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -515,13 +540,15 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -561,7 +588,8 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(16, Where), BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -591,10 +619,12 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -641,6 +671,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -660,7 +694,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -677,12 +712,14 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -706,6 +743,13 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -774,7 +818,7 @@ std::vector> JsExecutionProvider::GetCapabili candidates.push_back(node.Index()); tenative_candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) { diff --git a/onnxruntime/core/providers/js/js_export.cc b/onnxruntime/core/providers/js/js_export.cc index 2402bb33ce9d0..f99e90bcb13f6 100644 --- a/onnxruntime/core/providers/js/js_export.cc +++ b/onnxruntime/core/providers/js/js_export.cc @@ -6,8 +6,8 @@ #include "core/framework/op_kernel.h" const void* JsepOutput(void* context, int index, const void* data) { - const uint32_t* data_offset = reinterpret_cast(data); - uint32_t dim = *data_offset++; + const uintptr_t* data_offset = reinterpret_cast(data); + uintptr_t dim = *data_offset++; size_t dim_size = static_cast(dim); std::vector dims(dim_size); for (size_t i = 
0; i < dim_size; i++) { diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 7324b0d69474c..68d89c96d96f7 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -110,16 +110,17 @@ class JsKernel : public OpKernel { temp_data_size += sizeof(size_t) * 3; } } - uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); + uintptr_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); if (p_serialized_kernel_context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate memory for serialized kernel context."); } - p_serialized_kernel_context[0] = reinterpret_cast(context); - p_serialized_kernel_context[1] = static_cast(context->InputCount()); - p_serialized_kernel_context[2] = static_cast(context->OutputCount()); - p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); - p_serialized_kernel_context[4] = static_cast(custom_data_size); + p_serialized_kernel_context[0] = reinterpret_cast(context); + p_serialized_kernel_context[1] = static_cast(context->InputCount()); + p_serialized_kernel_context[2] = static_cast(context->OutputCount()); + p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); + p_serialized_kernel_context[4] = static_cast(custom_data_size); + size_t index = 5; for (int i = 0; i < context->InputCount(); i++) { const auto* input_ptr = context->Input(i); @@ -130,11 +131,11 @@ class JsKernel : public OpKernel { p_serialized_kernel_context[index++] = 0; continue; } - p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); - p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); - p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); + p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); + p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { - p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); } } @@ -199,9 +200,9 @@ class JsKernel : public OpKernel { return status; } - int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, - this, reinterpret_cast(p_serialized_kernel_context)); + intptr_t status_code = EM_ASM_INT( + { return Module.jsepRunKernel(Number($0), Number($1), Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, + this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". 
Y.data=" << (size_t)(context->Output(0)->DataRaw()) << "."; diff --git a/onnxruntime/core/providers/js/operators/cast.cc b/onnxruntime/core/providers/js/operators/cast.cc index 9b6ac6d7e253b..f499d0627e032 100644 --- a/onnxruntime/core/providers/js/operators/cast.cc +++ b/onnxruntime/core/providers/js/operators/cast.cc @@ -49,10 +49,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", CastOpTypeConstraints()) .TypeConstraint("T2", CastOpTypeConstraints()), Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 19, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); ONNX_OPERATOR_KERNEL_EX( Cast, kOnnxDomain, - 19, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", CastOpTypeConstraints()) diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 0357c2f02a7a2..b04df44954295 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -51,14 +51,14 @@ class ConvBase : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, - "dilations" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "dilations" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "group" : $4, - "kernel_shape" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], - "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], - "strides" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "kernel_shape" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], + "pads" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [], + "strides" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], + "w_is_const" : () JS_ARROW(!!HEAP8[Number($12)]), "activation" : UTF8ToString($13), - "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : [] + "activation_params" : $14 ? Array.from(HEAPF32.subarray(Number($14), Number($15))) : [] }), static_cast(conv_attrs_.auto_pad), JSEP_HEAP32_INDEX_START(dilations), diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index c51bf5ce9d4a6..5ff52e8fda4fa 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -48,8 +48,8 @@ class ConvTranspose : public JsKernel { "pads" : [ $5, $6 ], "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), - "outputPadding" : $10 ? Array.from(HEAP32.subarray($10, $11)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($12, $13)) : [], + "outputPadding" : $10 ? Array.from(HEAP32.subarray(Number($10), Number($11))) : [], + "outputShape" : $12 ? Array.from(HEAP32.subarray(Number($12), Number($13))) : [], "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), @@ -99,14 +99,14 @@ class ConvTranspose : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $7 ? 
"NHWC" : "NCHW", "autoPad" : $1, - "dilations" : Array.from(HEAP32.subarray($2, ($2 >>> 0) + /* dialations_vec_size */ 2)), + "dilations" : Array.from(HEAP32.subarray(Number($2), (Number($2) >>> 0) + /* dialations_vec_size */ 2)), "group" : $3, - "kernelShape" : Array.from(HEAP32.subarray($4, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)), - "pads" : Array.from(HEAP32.subarray($5, ($5 >>> 0) + /* pads_vec_size */ 4)), - "strides" : Array.from(HEAP32.subarray($6, ($6 >>> 0) + /* strides_vec_size */ 2)), + "kernelShape" : Array.from(HEAP32.subarray(Number($4), (Number($4) >>> 0) + /* kernel_shape_vec_size */ 2)), + "pads" : Array.from(HEAP32.subarray(Number($5), (Number($5) >>> 0) + /* pads_vec_size */ 4)), + "strides" : Array.from(HEAP32.subarray(Number($6), (Number($6) >>> 0) + /* strides_vec_size */ 2)), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), - "outputPadding" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "outputShape" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [], + "outputPadding" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], + "outputShape" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [], "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), diff --git a/onnxruntime/core/providers/js/operators/flatten.cc b/onnxruntime/core/providers/js/operators/flatten.cc index 1aacae819e304..44a67cb15d958 100644 --- a/onnxruntime/core/providers/js/operators/flatten.cc +++ b/onnxruntime/core/providers/js/operators/flatten.cc @@ -36,10 +36,20 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Flatten); + ONNX_OPERATOR_KERNEL_EX( Flatten, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc index 485cd3da9b91b..e9c6f5c79294f 100644 --- a/onnxruntime/core/providers/js/operators/gather.cc +++ b/onnxruntime/core/providers/js/operators/gather.cc @@ -15,11 +15,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -30,11 +26,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -44,11 +36,7 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); diff --git a/onnxruntime/core/providers/js/operators/gather_nd.cc b/onnxruntime/core/providers/js/operators/gather_nd.cc new file mode 100644 index 0000000000000..ee69100cc658e --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_nd.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/js/js_kernel.h" +#include "core/providers/js/js_data_types.h" +#include "gather_nd.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + GatherND, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + GatherND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GatherND, + kOnnxDomain, + 12, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + GatherND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GatherND, + kOnnxDomain, + 11, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + GatherND); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/gather_nd.h b/onnxruntime/core/providers/js/operators/gather_nd.h new file mode 100644 index 0000000000000..cdf7a52630dad --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_nd.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class GatherND : public JsKernel { + public: + GatherND(const OpKernelInfo& info) : JsKernel(info) { + int64_t batchDims = info.GetAttrOrDefault("batch_dims", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(GatherND, ({ + "batch_dims" : Number($1), + }), + static_cast(batchDims)); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/grid_sample.cc b/onnxruntime/core/providers/js/operators/grid_sample.cc new file mode 100644 index 0000000000000..84eb7df6c5bbe --- /dev/null +++ b/onnxruntime/core/providers/js/operators/grid_sample.cc @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "grid_sample.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GridSample, + kMSInternalNHWCDomain, + 16, 19, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", JsepSupportedDataTypes()) + .TypeConstraint("T2", JsepSupportedFloatTypes()), + GridSample); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GridSample, + kOnnxDomain, + 16, 19, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", JsepSupportedDataTypes()) + .TypeConstraint("T2", JsepSupportedFloatTypes()), + GridSample); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/grid_sample.h b/onnxruntime/core/providers/js/operators/grid_sample.h new file mode 100644 index 0000000000000..352decf33dc20 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/grid_sample.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +template +class GridSample : public JsKernel { + public: + GridSample(const OpKernelInfo& info) : JsKernel(info) { + int64_t align_corners = info.GetAttrOrDefault("align_corners", 0); + std::string mode = info.GetAttrOrDefault("mode", "linear"); + std::string padding_mode = info.GetAttrOrDefault("padding_mode", "zeros"); + int64_t channels_last = is_channels_last ? 
1 : info.GetAttrOrDefault("channels_last", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(GridSample, ({ + "align_corners" : $1, + "mode" : UTF8ToString($2), + "padding_mode" : UTF8ToString($3), + "format" : $4 ? "NHWC" : "NCHW" + }), + static_cast(align_corners), mode.c_str(), + padding_mode.c_str(), static_cast(channels_last)); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/if.cc b/onnxruntime/core/providers/js/operators/if.cc index ef072bb1635dd..368d1b5101bdb 100644 --- a/onnxruntime/core/providers/js/operators/if.cc +++ b/onnxruntime/core/providers/js/operators/if.cc @@ -44,9 +44,21 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, If); // opset-19 supports float8 +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 19, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + ONNX_OPERATOR_KERNEL_EX(If, kOnnxDomain, - 19, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc index 83fee35481aa6..556fdf419212f 100644 --- a/onnxruntime/core/providers/js/operators/pad.cc +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -56,10 +56,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 3), Pad); -ONNX_OPERATOR_KERNEL_EX( +ONNX_OPERATOR_VERSIONED_KERNEL_EX( Pad, kOnnxDomain, 19, + 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), + Pad); + +ONNX_OPERATOR_KERNEL_EX( + Pad, + kOnnxDomain, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedFloatTypes()) diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h index c18c7dd456dc2..f656462285bc4 100644 --- a/onnxruntime/core/providers/js/operators/pad.h +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -22,7 +22,7 @@ class Pad : public JsKernel, public PadBase { JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1, "value" : $2, - "pads" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "pads" : $3 ? 
Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(mode_), static_cast(value_), JSEP_HEAP32_INDEX_START(pads), diff --git a/onnxruntime/core/providers/js/operators/pool.cc b/onnxruntime/core/providers/js/operators/pool.cc index 7df1e483f52a1..50efafac7d3e6 100644 --- a/onnxruntime/core/providers/js/operators/pool.cc +++ b/onnxruntime/core/providers/js/operators/pool.cc @@ -55,8 +55,10 @@ POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9) POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 7, 9) POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10) POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 10, 10) -POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11) -POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 11, 18) +POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11, 18) +POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 19) +POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 19) POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1) POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1) diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h index 66bcde86020b6..32556eeaeefe4 100644 --- a/onnxruntime/core/providers/js/operators/pool.h +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -3,22 +3,22 @@ #pragma once -#include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/pool_base.h" +#include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ - "format" : $13 ? "NHWC" : "NCHW", \ - "auto_pad" : $1, \ - "ceil_mode" : $2, \ - "count_include_pad" : $3, \ - "storage_order" : $4, \ - "dilations" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], \ - "kernel_shape" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], \ - "pads" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], \ - "strides" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [] \ +#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ + "format" : $13 ? "NHWC" : "NCHW", \ + "auto_pad" : $1, \ + "ceil_mode" : $2, \ + "count_include_pad" : $3, \ + "storage_order" : $4, \ + "dilations" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], \ + "kernel_shape" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [], \ + "pads" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], \ + "strides" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [] \ }) #define POOL_ATTRIBUTES_PARAM_LIST \ diff --git a/onnxruntime/core/providers/js/operators/reduce.cc b/onnxruntime/core/providers/js/operators/reduce.cc index 2679cfed86124..98c329c1d9377 100644 --- a/onnxruntime/core/providers/js/operators/reduce.cc +++ b/onnxruntime/core/providers/js/operators/reduce.cc @@ -20,6 +20,16 @@ namespace js { // a new opset version update applies to Reduce* operators, we may need to add another macro like // REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT to set input memory type. // i.e. we cannot use REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL to version 18 when the opset version is increased. 
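Because opset 18 moved the ReduceMax/ReduceMin axes from an attribute to a second input tensor, the macro added below pins input 1 to CPU memory so the axes can be read on the host when the kernel is set up. A toy standalone example of reducing with runtime-supplied axes, purely to illustrate the semantics (not ORT code; keepdims handling is omitted):

    // Toy ReduceMax over a 2-D tensor with axes supplied at run time.
    #include <algorithm>
    #include <cstdint>
    #include <limits>
    #include <vector>

    std::vector<float> ReduceMax2D(const std::vector<float>& data, int64_t rows, int64_t cols,
                                   const std::vector<int64_t>& axes) {
      const bool reduce_rows = std::find(axes.begin(), axes.end(), 0) != axes.end();
      const bool reduce_cols = std::find(axes.begin(), axes.end(), 1) != axes.end();

      const int64_t out_rows = reduce_rows ? 1 : rows;
      const int64_t out_cols = reduce_cols ? 1 : cols;
      std::vector<float> out(static_cast<size_t>(out_rows * out_cols),
                             std::numeric_limits<float>::lowest());

      for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < cols; ++c) {
          const int64_t out_r = reduce_rows ? 0 : r;
          const int64_t out_c = reduce_cols ? 0 : c;
          float& slot = out[static_cast<size_t>(out_r * out_cols + out_c)];
          slot = std::max(slot, data[static_cast<size_t>(r * cols + c)]);
        }
      }
      return out;
    }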
+#define REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceOp, sinceVersion, endVersion) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + sinceVersion, endVersion, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + ReduceOp); #define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \ ONNX_OPERATOR_KERNEL_EX( \ @@ -41,13 +51,15 @@ REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17); -REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMax, 18); +REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMax, 18, 19); +REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMax, 20); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 11, 11); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 12, 12); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 13, 17); -REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMin, 18); +REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMin, 18, 19); +REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMin, 20); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12); diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 937f1f990dc67..4ae558f9dfc00 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -8,29 +8,29 @@ namespace onnxruntime { namespace js { -#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ - class ReduceKernel : public JsKernel, public ReduceKernelBase { \ - public: \ - using ReduceKernelBase::axes_; \ - using ReduceKernelBase::noop_with_empty_axes_; \ - using ReduceKernelBase::keepdims_; \ - ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ - std::vector axes(axes_.size()); \ - if (axes_.size() > 0) { \ - std::transform(axes_.begin(), axes_.end(), axes.begin(), \ - [](int64_t axis) { return gsl::narrow_cast(axis); }); \ - } \ - JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ - "keepDims" : !!$1, \ - "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray($3, $4))) : [], \ - }), \ - static_cast(keepdims_), \ - static_cast(noop_with_empty_axes_), \ - JSEP_HEAP32_INDEX_START(axes), \ - JSEP_HEAP32_INDEX_END(axes)); \ - } \ +#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ + template \ + class ReduceKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::noop_with_empty_axes_; \ + using ReduceKernelBase::keepdims_; \ + ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + } \ + JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ + "keepDims" : !!$1, \ + "noopWithEmptyAxes" : !!$2, \ + "axes" : $3 ? 
(Array.from(HEAP32.subarray(Number($3), Number($4)))) : [], \ + }), \ + static_cast(keepdims_), \ + static_cast(noop_with_empty_axes_), \ + JSEP_HEAP32_INDEX_START(axes), \ + JSEP_HEAP32_INDEX_END(axes)); \ + } \ }; JSEP_DEFINE_REDUCE_KERNEL(ReduceMax); diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 134eb4bf5a7f4..3e8ccf40753c8 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase { std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast(axis); }); JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({ "antialias" : $1, - "axes" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "axes" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "coordinateTransformMode" : UTF8ToString($4), "cubicCoeffA" : $5, "excludeOutside" : $6, diff --git a/onnxruntime/core/providers/js/operators/scatter_nd.cc b/onnxruntime/core/providers/js/operators/scatter_nd.cc new file mode 100644 index 0000000000000..e9edb7f58fe5e --- /dev/null +++ b/onnxruntime/core/providers/js/operators/scatter_nd.cc @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/js/js_data_types.h" +#include "scatter_nd.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + ScatterND, + kOnnxDomain, + 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 16, + 17, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 13, + 15, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 11, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/scatter_nd.h b/onnxruntime/core/providers/js/operators/scatter_nd.h new file mode 100644 index 0000000000000..8c81c62d71fe7 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/scatter_nd.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
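The ScatterND kernel declared in this new header (see below) logs a warning when reduction is "none" because, with duplicate indices, the result depends on the order in which updates are applied; with an associative reduction such as "add" the order no longer matters. A minimal standalone sketch of that behavior (illustrative only, not the EP implementation):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> none_result{0.f, 0.f, 0.f, 0.f};
      std::vector<float> add_result{0.f, 0.f, 0.f, 0.f};
      const std::vector<int64_t> indices{1, 1};  // duplicate index
      const std::vector<float> updates{5.f, 7.f};
      for (size_t i = 0; i < indices.size(); ++i) {
        const auto idx = static_cast<size_t>(indices[i]);
        none_result[idx] = updates[i];   // reduction == "none": last write wins -> 7
        add_result[idx] += updates[i];   // reduction == "add": order-independent -> 12
      }
      std::printf("none: %.1f, add: %.1f\n", none_result[1], add_result[1]);
      return 0;
    }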
+ +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace js { + +enum class ScatterNDReduction : int { + None = 0, + Add = 1, + Mul = 2, + Min = 3, + Max = 4, +}; + +class ScatterND : public JsKernel { + public: + ScatterND(const OpKernelInfo& info) : JsKernel(info) { + std::string reduction = info.GetAttrOrDefault("reduction", "none"); + if (reduction == "add") { + reduction_ = ScatterNDReduction::Add; + } else if (reduction == "mul") { + reduction_ = ScatterNDReduction::Mul; + } else if (reduction == "min") { + reduction_ = ScatterNDReduction::Min; + } else if (reduction == "max") { + reduction_ = ScatterNDReduction::Max; + } else if (reduction == "none") { + LOGS_DEFAULT(WARNING) << "ScatterND with reduction=='none' only guarantees " + << "to be correct if indices are not duplicated."; + } else { + ORT_THROW("Reduction '", reduction, "' is not supported on webgpu when opset <= 13."); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(ScatterND, ({ + "reduction" : UTF8ToString($1), + }), + reduction.c_str()); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + Tensor* Y = context->Output(0, X_shape); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + return ComputeInternal(context); + } + + private: + ScatterNDReduction reduction_{ScatterNDReduction::None}; +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h index daeffaa664741..f30e7bf01ec7b 100644 --- a/onnxruntime/core/providers/js/operators/slice.h +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -20,9 +20,9 @@ class Slice : public JsKernel, public SliceBase { std::vector starts(attr_starts.begin(), attr_starts.end()); std::vector ends(attr_ends.begin(), attr_ends.end()); - JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [], - "ends" : $3 ? Array.from(HEAP32.subarray($3, $4)) : [], - "axes" : $5 ? Array.from(HEAP32.subarray($5, $6)) : []}), + JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : [], + "ends" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : [], + "axes" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : []}), JSEP_HEAP32_INDEX_START(starts), JSEP_HEAP32_INDEX_END(starts), JSEP_HEAP32_INDEX_START(ends), diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 4fdbab00e739c..3f6cfcb8921f3 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -49,7 +49,7 @@ class Split : public JsKernel, public SplitBase { JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, - "splitSizes" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "splitSizes" : $3 ? 
Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(axis_), static_cast(num_outputs_), JSEP_HEAP32_INDEX_START(split_sizes), diff --git a/onnxruntime/core/providers/js/operators/squeeze.cc b/onnxruntime/core/providers/js/operators/squeeze.cc index e858ade348cd4..521d0103d373f 100644 --- a/onnxruntime/core/providers/js/operators/squeeze.cc +++ b/onnxruntime/core/providers/js/operators/squeeze.cc @@ -10,7 +10,7 @@ namespace js { ONNX_OPERATOR_KERNEL_EX( Squeeze, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedDataTypes()) @@ -19,6 +19,17 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 1), Squeeze); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Squeeze, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/transpose.cc b/onnxruntime/core/providers/js/operators/transpose.cc index 332bd35f2434c..136879b93b37f 100644 --- a/onnxruntime/core/providers/js/operators/transpose.cc +++ b/onnxruntime/core/providers/js/operators/transpose.cc @@ -15,10 +15,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T", JsepSupportedDataTypes()), Transpose); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + Transpose); + ONNX_OPERATOR_KERNEL_EX( Transpose, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedDataTypes()), diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index 7a945471c7701..f6b2b4faba850 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -21,7 +21,7 @@ class Transpose final : public JsKernel, public TransposeBase { } } JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({ - "perm" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [] + "perm" : $1 ? 
Array.from(HEAP32.subarray(Number($1), Number($2))) : [] }), JSEP_HEAP32_INDEX_START(perm), JSEP_HEAP32_INDEX_END(perm)); diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.cc b/onnxruntime/core/providers/js/operators/unsqueeze.cc index 1485e800e5e76..898deb827cccb 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.cc +++ b/onnxruntime/core/providers/js/operators/unsqueeze.cc @@ -10,7 +10,7 @@ namespace js { ONNX_OPERATOR_KERNEL_EX( Unsqueeze, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedDataTypes()) @@ -19,6 +19,17 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 1), Unsqueeze); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Unsqueeze, kOnnxDomain, diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 94480c308b99f..77c5e18a5878e 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -2,12 +2,16 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" -#include "gpu_data_transfer.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_call.h" + +// If you make change below, please also update onnxruntime/core/providers/rocm/gpu_data_transfer.cc namespace onnxruntime { + bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -23,17 +27,24 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const if (src_device.Type() == OrtDevice::GPU) { // Copy only if the two addresses are different. if (dst_data != src_data) { - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else { // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) { + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. 
+ HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); + } } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); } else { // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } @@ -49,23 +60,28 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { - // copy from pinned memory to GPU, this is non-blocking - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); + if (src_device.Type() == OrtDevice::CPU) { + // If the source is not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); - } else { - // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); } } else if (src_device.Type() == OrtDevice::GPU) { - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + // If the destination is not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously.
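The comments above capture the key constraint: hipMemcpyAsync only overlaps with host work when the host buffer is page-locked; otherwise the runtime falls back to an effectively synchronous copy. A minimal standalone sketch of the pinned-memory path (assumes a HIP toolchain; buffer names are illustrative):

    #include <hip/hip_runtime.h>

    int main() {
      constexpr size_t bytes = 1 << 20;
      float* pinned = nullptr;
      float* device_buf = nullptr;
      hipStream_t stream{};
      if (hipHostMalloc(reinterpret_cast<void**>(&pinned), bytes, hipHostMallocDefault) != hipSuccess) return 1;
      if (hipMalloc(reinterpret_cast<void**>(&device_buf), bytes) != hipSuccess) return 1;
      if (hipStreamCreate(&stream) != hipSuccess) return 1;
      // Asynchronous with respect to the host because the source is page-locked;
      // a pageable source would make this call behave like a blocking memcpy.
      hipMemcpyAsync(device_buf, pinned, bytes, hipMemcpyHostToDevice, stream);
      hipStreamSynchronize(stream);  // wait before reusing or freeing 'pinned'
      hipStreamDestroy(stream);
      hipFree(device_buf);
      hipHostFree(pinned);
      return 0;
    }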
+ HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); } else { - // copying between cpu memory + if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + // sync the stream first to make sure the data arrived + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); + } + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index c9db31e8744a7..3d9ae2bf7e6ff 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -51,7 +51,7 @@ void* MIGraphXExternalAllocator::Alloc(size_t size) { void MIGraphXExternalAllocator::Free(void* p) { free_(p); - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); auto it = reserved_.find(p); if (it != reserved_.end()) { reserved_.erase(it); @@ -62,7 +62,7 @@ void MIGraphXExternalAllocator::Free(void* p) { void* MIGraphXExternalAllocator::Reserve(size_t size) { void* p = Alloc(size); if (!p) return nullptr; - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); ORT_ENFORCE(reserved_.find(p) == reserved_.end()); reserved_.insert(p); return p; diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index 64da844e8c714..c8c935eba44ab 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -5,7 +5,7 @@ #include #include "core/framework/allocator.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -42,7 +42,7 @@ class MIGraphXExternalAllocator : public MIGraphXAllocator { void* Reserve(size_t size) override; private: - mutable OrtMutex lock_; + mutable std::mutex lock_; ExternalAlloc alloc_; ExternalFree free_; ExternalEmptyCache empty_cache_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index f6a95cebf34b5..6d514e01aea96 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -3,6 +3,7 @@ #pragma once #include "migraphx_inc.h" +#include "core/common/common.h" namespace onnxruntime { @@ -16,5 +17,6 @@ std::conditional_t RocmCall( #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) +#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 90dfa49c73c9a..9017b36a0f087 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -49,6 +49,8 @@ class Memcpy final : public OpKernel { const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); if (!gpu_data_transfer) return Status(common::ONNXRUNTIME, common::EP_FAIL, "gpu data transfer is missing in Migraphx EP."); + // CopyTensorAsync could handle both pinned memory and non-pinned CPU memory. 
+ // For non-pinned CPU memory, the copy is synchronous. return gpu_data_transfer->CopyTensorAsync(*X, *Y, *(ctx->GetComputeStream())); } }; @@ -800,6 +802,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "ATen", "AveragePool", "BatchNormalization", + "BiasGelu", "Cast", "Ceil", "Celu", @@ -824,17 +827,21 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Exp", "Expand", "EyeLike", + "FastGelu", "Flatten", "Floor", "GRU", "Gather", "GatherElements", "GatherND", + "Gelu", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Greater", "GreaterOrEqual", + "GroupNormalization", + "GroupQueryAttention", "HardSigmoid", "HardSwish", "Identity", @@ -842,6 +849,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "ImageScaler", "InstanceNormalization", "IsNan", + "LayerNormalization", "LeakyRelu", "Less", "LessOrEqual", @@ -853,6 +861,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "LSTM", "MatMul", "MatMulInteger", + "MatMulNBits", "Max", "MaxPool", "Mean", @@ -861,6 +870,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Mul", "Multinomial", "Neg", + "NegativeLogLikelihoodLoss", "NonMaxSuppression", "NonZero", "Not", @@ -904,10 +914,13 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Shape", "Sigmoid", "Sign", + "SimplifiedLayerNormalization", "Sin", "Sinh", + "SkipSimplifiedLayerNormalization", "Slice", "Softmax", + "SoftmaxCrossEntropyLoss", "Softplus", "Softsign", "SpaceToDepth", @@ -1019,15 +1032,6 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v return result; } - // migraphx cannot handle Loop, If, and SoftmaxCrossEntropyLoss for now, - // so if a model contain any of these operators, fall back to CPU - std::unordered_set vec_ops = {"SoftmaxCrossEntropyLoss"}; - if (std::any_of(unsupported_nodes.begin(), unsupported_nodes.end(), [&](auto i) { - return (vec_ops.count(graph_viewer.GetNode(i)->OpType()) > 0); - })) { - return result; - } - auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); // check whether a subgrap should fallback to CPU @@ -1153,7 +1157,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!no_input_shape) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(INFO) << "No Input shapes detected quantizing model"; + LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops @@ -1294,7 +1298,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // re-compile the program if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl; + LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling" << std::endl; #ifndef ENABLE_TRAINING_CORE #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2) cmp_options.set_external_data_path(model_path_.has_parent_path() ? 
model_path_.parent_path().string() : std::filesystem::current_path().string()); @@ -1420,7 +1424,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& { // lock to avoid race condition - std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); + std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); void* rocm_stream; Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream)); @@ -1438,7 +1442,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::vector ort_shape{res_lens.begin(), res_lens.end()}; auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); void* output_data = output_tensor.GetTensorMutableRawData(); - HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice)); + HIP_CALL_THROW(hipMemcpyWithStream(output_data, + gpu_res.data(), + res_shape.bytes(), + hipMemcpyDeviceToDevice, + static_cast(rocm_stream))); } } }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 21679d1f6f151..91b6a4741b55e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -5,7 +5,7 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/migraphx/migraphx_inc.h" @@ -40,7 +40,7 @@ struct MIGraphXFuncState { migraphx::onnx_options options; migraphx::target t{}; std::unordered_map input_name_indexes; - OrtMutex* mgx_mu_ptr = nullptr; + std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; bool int8_enable = false; @@ -101,7 +101,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::string load_compiled_path_; bool dump_model_ops_ = false; migraphx::target t_; - OrtMutex mgx_mu_; + std::mutex mgx_mu_; hipStream_t stream_ = nullptr; bool exhaustive_tune_ = false; mutable std::filesystem::path model_path_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index 03a7c1607e3ad..85b0aff87a436 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -6,8 +6,6 @@ #include "migraphx_inc.h" #include "migraphx_call.h" -#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) - namespace onnxruntime { void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 12416ea0c121b..e4bee6f959a01 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -32,8 +32,16 @@ namespace nnapi { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, gsl::span nnapi_target_devices, - TargetDeviceOption target_device_option) - : nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), 
nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) { + TargetDeviceOption target_device_option, + const logging::Logger& logger) + : nnapi_(nnapi_handle), + graph_viewer_(graph_viewer), + nnapi_model_{std::make_unique(nnapi_handle)}, + shaper_{graph_viewer}, + nnapi_target_devices_(nnapi_target_devices), + target_device_option_(target_device_option), + nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)), + logger_(logger) { nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_; } @@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { } void ModelBuilder::PreprocessNodeUnits() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); } // Help to get all quantized operators' input and the NodeUnit(s) using the input diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index b2118150dd304..4db335afa98b0 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -14,7 +14,9 @@ struct NnApi; namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; enum class DataLayout; class NodeUnit; @@ -31,7 +33,8 @@ class ModelBuilder { using Shape = Shaper::Shape; ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, - gsl::span nnapi_target_devices, TargetDeviceOption target_device_option); + gsl::span nnapi_target_devices, TargetDeviceOption target_device_option, + const logging::Logger& logger); common::Status Compile(std::unique_ptr& model); @@ -173,6 +176,9 @@ class ModelBuilder { // <1,1> <1,2> <1,3> InlinedVector> operations_recorder_; #endif + + const logging::Logger& logger_; + // Convert the ONNX model to ANeuralNetworksModel common::Status Prepare(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h index 3ff28d52e470f..643209fbe72b0 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h @@ -6,7 +6,7 @@ #include #include "builders/shaper.h" -#include "core/platform/ort_mutex.h" +#include #include "nnapi_lib/NeuralNetworksWrapper.h" struct NnApi; @@ -98,7 +98,7 @@ class Model { void SetDynamicOutputBufferSize(size_t size) { dynamic_output_buffer_size_ = size; } // Mutex for exclusive lock to this model object - OrtMutex& GetMutex() { return mutex_; } + std::mutex& GetMutex() { return mutex_; } // If the given output is a scalar output // Since NNAPI does not support tensor with empty shape (scalar), we use {1} tensor for scalar in NNAPI @@ -130,7 +130,7 @@ class Model { // This is map is to lookup the nnapi output from the onnx output std::unordered_map onnx_to_nnapi_output_map_; - OrtMutex mutex_; + std::mutex mutex_; void AddInput(const std::string& name, const android::nn::wrapper::OperandType& operand_type); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 4d2888222ff0f..f92c9592742d5 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ 
b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; // TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators) @@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL; #endif }(); - LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; + LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; const nnapi::OpSupportCheckParams params{ android_feature_level, @@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view }; if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) { - LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" + LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" << params.android_feature_level << "] is lower than minimal supported NNAPI API feature level [" << ORT_NNAPI_MIN_API_LEVEL @@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view node_unit_supported_result[node_unit] = supported; } - LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported + LOGS(logger, VERBOSE) << "Node supported: [" << supported << "] Operator type: [" << node.OpType() << "] index: [" << node.Index() << "] name: [" << node.Name() @@ -224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // If the graph is partitioned in multiple subgraphs, and this may impact performance, // we want to give users a summary message at warning level. 
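This part of the change threads the session's logger through GetCapability and Compile so that messages such as the partition summary here go to the per-session logger rather than the process-wide default. A standalone sketch of the pattern with hypothetical stand-in types (not ORT's logging classes):

    #include <iostream>
    #include <string>

    // Hypothetical minimal logger; ORT's logging::Logger and the LOGS(...) macro play this role.
    struct Logger {
      std::string id;
      void Log(const std::string& severity, const std::string& msg) const {
        std::cout << "[" << id << "][" << severity << "] " << msg << "\n";
      }
    };

    // Mirrors the shape of the change: the provider uses the logger it was handed
    // instead of writing to a shared process-wide default logger.
    void ReportPartitions(const Logger& session_logger, int num_of_partitions) {
      const std::string summary =
          "graph is partitioned into " + std::to_string(num_of_partitions) + " subgraph(s)";
      if (num_of_partitions > 1) {
        session_logger.Log("WARNING", summary);  // multiple partitions may impact performance
      } else {
        session_logger.Log("INFO", summary);
      }
    }

    int main() {
      const Logger session_logger{"session-0"};
      ReportPartitions(session_logger, 2);
      return 0;
    }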
if (num_of_partitions > 1) { - LOGS_DEFAULT(WARNING) << summary_msg; + LOGS(logger, WARNING) << summary_msg; } else { - LOGS_DEFAULT(INFO) << summary_msg; + LOGS(logger, INFO) << summary_msg; } return result; @@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context, common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { using namespace android::nn::wrapper; + const auto& logger = *GetLogger(); + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_); + nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger); builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW); builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16); @@ -380,7 +383,7 @@ common::Status NnapiExecutionProvider::Compile(const std::vector execution; - std::unique_lock lock(model->GetMutex()); + std::unique_lock lock(model->GetMutex()); ORT_RETURN_IF_ERROR(model->PrepareForExecution(execution)); ORT_RETURN_IF_ERROR(execution->SetInputBuffers(inputs)); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 4fca4037301fb..a0bcf953938d9 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -25,6 +25,11 @@ GlobalContext& BackendManager::GetGlobalContext() { return global_context_; } +ov::CompiledModel& BackendManager::GetOVCompiledModel() { + ov::CompiledModel& ov_ptr = concrete_backend_->GetOVCompiledModel(); + return (ov_ptr); +} + BackendManager::BackendManager(const GlobalContext& global_context, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, @@ -35,7 +40,7 @@ BackendManager::BackendManager(const GlobalContext& global_context, openvino_sdk_version_ = std::to_string(global_context_.OpenVINO_Version.at(0)) + "." 
+ std::to_string(global_context_.OpenVINO_Version.at(1)); if (ep_ctx_handle_.CheckForOVEPCtxNode(subgraph, openvino_sdk_version_)) { - if (ep_ctx_handle_.ImportBlobFromEPCtxModel(subgraph) != Status::OK()) + if (ep_ctx_handle_.ImportBlobFromEPCtxModel(subgraph, global_context_.ep_context_embed_mode) != Status::OK()) ORT_THROW("Import blob from model failed"); } @@ -65,7 +70,10 @@ BackendManager::BackendManager(const GlobalContext& global_context, i++; } subgraph_context_.subgraph_name = fused_node.Name(); - auto model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); + std::unique_ptr model_proto; + if (!ep_ctx_handle_.IsValidOVEPCtxGraph()) { + model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); + } std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type; if (ModelHasSymbolicInputDims(subgraph)) { diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index b9ff7a72372b3..5ec462afd9d01 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -30,6 +30,7 @@ class BackendManager { GlobalContext& GetGlobalContext(); Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger); + ov::CompiledModel& GetOVCompiledModel(); private: std::unique_ptr GetModelProtoFromFusedNode( diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index f772b9c3b0478..b97736f2e124d 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -39,7 +39,7 @@ struct static_cast_int64 { int64_t operator()(const T1& x) const { return static_cast(x); } }; -std::shared_ptr +std::shared_ptr CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, std::map>& const_outputs_map) { if (IsCILogEnabled()) { @@ -47,13 +47,13 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext } const std::string model = model_proto.SerializeAsString(); try { - auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); + auto ov_model = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); // Check for Constant Folding - if (!global_context.is_wholly_supported_graph) { + if ((global_context.device_type != "NPU") && !global_context.is_wholly_supported_graph) { ov::pass::ConstantFolding pass_const_obj; - pass_const_obj.run_on_model(cnn_network); - auto& results = const_cast(cnn_network.get()->get_results()); + pass_const_obj.run_on_model(ov_model); + auto& results = const_cast(ov_model.get()->get_results()); size_t index = results.size() - 1; for (auto it = results.rbegin(); it != results.rend(); ++it) { @@ -67,12 +67,12 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext } #ifndef NDEBUG if (IsDebugEnabled()) { - std::string name = cnn_network->get_friendly_name(); + std::string name = ov_model->get_friendly_name(); ov::pass::Serialize serializer(name + ".xml", name + ".bin"); - serializer.run_on_model(cnn_network); + serializer.run_on_model(ov_model); } #endif - return cnn_network; + return ov_model; } catch (std::string const& msg) { ORT_THROW(msg); } diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index 
9e65770da7d23..9d58e1ca73abb 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -60,7 +60,7 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, size_t batch_slice_idx); -std::shared_ptr +std::shared_ptr CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, std::map>& const_outputs_map); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 71a02f076c8cc..435ca83ff69d4 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -48,6 +48,16 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // Set the inference_num_threads property of the CPU SetNumThreads(device_config); + auto npuw_status = + std::any_of(device_config.begin(), device_config.end(), [&](const std::pair& pair) { + return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is()) && + (pair.second.as() == "YES"); + }); + + if (npuw_status) { + LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation"; + } + try { std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str; @@ -81,9 +91,9 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr device_config, global_context_.ep_context_embed_mode, subgraph_context_.subgraph_name); - ie_cnn_network_ = exe_network_.Get().get_runtime_model(); } else if (global_context_.export_ep_ctx_blob && - hw_target.find("NPU") != std::string::npos) { + hw_target.find("NPU") != std::string::npos && + !global_context_.has_external_weights) { std::shared_ptr ov_model; { const std::string model = model_proto->SerializeAsString(); @@ -93,7 +103,8 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr ov_model = global_context_.ie_core.Get().read_model(model, ov::Tensor()); } exe_network_ = OVExeNetwork(global_context_.ie_core.Get().compile_model(ov_model, hw_target, device_config)); - } else if ((!subgraph_context_.has_dynamic_input_shape) && + } else if (!global_context_.has_external_weights && + (!subgraph_context_.has_dynamic_input_shape) && ((hw_target.find("AUTO") == std::string::npos) || (global_context_.OpenVINO_Version.at(0) >= 2024 && global_context_.OpenVINO_Version.at(1) > 2))) { // Optimized OV compile_model API is supported with AUTO from version 2024.3 and above @@ -104,22 +115,22 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr device_config, subgraph_context_.subgraph_name); } else { // For all other types use ov::Model Type - ie_cnn_network_ = CreateOVModel(*model_proto, global_context_, const_outputs_map_); + auto ov_model = CreateOVModel(*model_proto, global_context_, const_outputs_map_); exe_network_ = global_context_.ie_core.CompileModel( - ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + ov_model, hw_target, device_config, subgraph_context_.subgraph_name); } #endif } else { // Full graph is not supported - ie_cnn_network_ = CreateOVModel(*model_proto, global_context_, const_outputs_map_); + auto ov_model = CreateOVModel(*model_proto, global_context_, const_outputs_map_); exe_network_ = global_context_.ie_core.CompileModel( - ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + ov_model, hw_target, device_config, subgraph_context_.subgraph_name); } LOGS_DEFAULT(INFO) << 
log_tag << "Loaded model to the plugin"; } catch (const char* msg) { ORT_THROW(msg); } - - inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, 1)); + int num_infer_req = (global_context_.num_of_threads > 0) ? global_context_.num_of_threads : 1; + inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req)); } bool BasicBackend::ValidateSubgraph(std::map>& const_outputs_map) { @@ -143,8 +154,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_config.emplace(ov::hint::inference_precision("f32")); } if (global_context_.precision_str.find("ACCURACY") != std::string::npos && - global_context_.device_type == "GPU") { - if (global_context_.OpenVINO_Version.at(0) >= 2024 && global_context_.OpenVINO_Version.at(1) >= 1) { + global_context_.device_type.find("GPU") != std::string::npos) { + if (global_context_.OpenVINO_Version.at(0) >= 2024) { device_config.emplace(ov::hint::inference_precision(ov::element::undefined)); device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY)); } else { @@ -172,12 +183,110 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type); } device_config.emplace(ov::device::properties("NPU", device_property)); -#if (OPENVINO_VERSION_MAJOR >= 2024) && (OPENVINO_VERSION_MINOR > 3) +#if (((OPENVINO_VERSION_MAJOR == 2024) && (OPENVINO_VERSION_MINOR > 3)) || (OPENVINO_VERSION_MAJOR > 2024)) if (global_context_.export_ep_ctx_blob) { global_context_.ie_core.Get().set_property("NPU", ov::intel_npu::bypass_umd_caching(true)); } #endif } + + if (!global_context_.load_config.empty()) { + const std::map& target_config = global_context_.load_config; + + if (global_context_.device_type.find("NPU") != std::string::npos) { + auto npuw_config = target_config.at("NPU"); + + // Check if "NPU_USE_NPUW" exists and is set to "YES" + auto npu_use_npuw_it = npuw_config.find("NPU_USE_NPUW"); + if (npu_use_npuw_it != npuw_config.end() && + npu_use_npuw_it->second.is() && + npu_use_npuw_it->second.as() == "YES") { + // Only add NPUW-related keys if NPU_USE_NPUW is "YES" + for (const auto& [key, value] : npuw_config) { + if (key.find("NPUW") != std::string::npos) { + if (!value.is()) { + LOGS_DEFAULT(ERROR) << "Invalid value type for key: " << key; + continue; + } + device_config[key] = value; + } + } + } else { + // Check if there are any "NPUW" keys and log a warning + if (std::any_of(npuw_config.begin(), npuw_config.end(), + [&](const auto& pair) { return pair.first.find("NPUW") != std::string::npos; })) { + LOGS_DEFAULT(WARNING) << "Skipping NPUW-related configurations as NPU_USE_NPUW is not set to 'YES'."; + } + } + } + + // Parse device types like "AUTO:CPU,GPU" and extract individual devices + auto parse_individual_devices = [&](const std::string& device_type) -> std::vector { + std::vector devices; + auto delimiter_pos = device_type.find(':'); + if (delimiter_pos != std::string::npos) { + std::stringstream str_stream(device_type.substr(delimiter_pos + 1)); + std::string device; + while (std::getline(str_stream, device, ',')) { + devices.emplace_back(device); + } + } else { + devices.emplace_back(device_type); + } + return devices; + }; + + // Check if a property is supported and mutable + auto is_supported_and_mutable = [&](const std::string& key, + const std::vector& supported_config) -> bool { + auto it = std::find_if(supported_config.begin(), supported_config.end(), [&](const ov::PropertyName& 
property) { + return property == key && property.is_mutable(); + }); + return it != supported_config.end(); + }; + + // Set properties if they are valid, else log a warning if the property is missing or immutable by skipping the same + auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options, + const std::vector& supported_properties) { + for (const auto& [key, value] : config_options) { + if (key.find("NPUW") != std::string::npos) { + continue; + } + if (is_supported_and_mutable(key, supported_properties)) { + global_context_.ie_core.Get().set_property(device, ov::AnyMap{{key, value}}); + } else { + LOGS_DEFAULT(WARNING) << "WARNING: Property \"" << key + << "\" is either unsupported in current OpenVINO version" + << " or property is immutable for target device \"" + << device << "\". Skipping setting this property."; + } + } + }; + + // Check if the device type is AUTO, HETERO, or MULTI + if (global_context_.device_type.find("AUTO") == 0 || + global_context_.device_type.find("HETERO") == 0 || + global_context_.device_type.find("MULTI") == 0) { + // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"]) + auto individual_devices = parse_individual_devices(global_context_.device_type); + // Set properties only for individual devices (e.g., "CPU", "GPU") + for (const std::string& device : individual_devices) { + if (target_config.count(device)) { + // Get supported properties for each individual device + auto device_properties = global_context_.ie_core.Get().get_property(device, ov::supported_properties); + // Set properties for the device + set_target_properties(device, target_config.at(device), device_properties); + } + } + } else { + if (target_config.count(global_context_.device_type)) { + auto supported_properties = global_context_.ie_core.Get().get_property(global_context_.device_type, + ov::supported_properties); + set_target_properties(global_context_.device_type, + target_config.at(global_context_.device_type), supported_properties); + } + } + } } void BasicBackend::EnableCaching(ov::AnyMap& device_config) { @@ -275,7 +384,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque input_tensor_shape[tensor_iter] = *i; tensor_iter += 1; } - auto input = graph_input_info.at(input_idx); + const auto& input = graph_input_info.at(input_idx); OVTensorPtr tensor_ptr; // avoid input copies on the CPU device if (global_context_.device_type.find("CPU") != std::string::npos) { @@ -308,7 +417,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if ((it == ort_ov_tensor_map.end()) || (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { ov_tensor_data_t ov_tensor_data; - auto input = graph_input_info.at(input_idx); + const auto& input = graph_input_info.at(input_idx); ov_tensor_data.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), const_cast(tensor.GetTensorRawData())); @@ -316,7 +425,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; try { - infer_request->SetTensor(input_name, ov_tensor_data.tensor_ptr); + infer_request->SetTensor(std::move(input_name), ov_tensor_data.tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } @@ -354,14 +463,14 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if ((it == ort_ov_tensor_map.end()) || (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != 
tensor.GetTensorRawData()))) { ov_tensor_data_t ov_tensor_data; - auto output = graph_output_info.at(output_idx); + const auto& output = graph_output_info.at(output_idx); ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), const_cast(tensor.GetTensorRawData())); ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; try { - infer_request->SetTensor(output_name, ov_tensor_data.tensor_ptr); + infer_request->SetTensor(std::move(output_name), ov_tensor_data.tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } @@ -593,7 +702,6 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { // Requesting for an idle infer_request from a pool of infer_requests_ OVInferRequestPtr infer_request; infer_request = inferRequestsQueue_->getIdleRequest(); - #ifdef IO_BUFFER_ENABLED if ((global_context_.device_type.find("GPU") != std::string::npos) && (global_context_.context != nullptr) && global_context_.is_wholly_supported_graph) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 12502a1d83c5d..3fcf6e4384d52 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -58,7 +58,6 @@ class BasicBackend : public IBackend { GlobalContext& global_context_; SubGraphContext subgraph_context_; mutable std::mutex compute_lock_; - std::shared_ptr ie_cnn_network_; OVExeNetwork exe_network_; std::map> const_outputs_map_; std::unique_ptr inferRequestsQueue_; diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 598e985676f8d..4f970bc7bc287 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include #include "core/providers/openvino/ov_interface.h" @@ -15,18 +16,19 @@ namespace openvino_ep { struct GlobalContext { OVCore ie_core; bool is_wholly_supported_graph = false; - bool enable_npu_fast_compile = false; bool enable_opencl_throttling = false; bool disable_dynamic_shapes = false; - bool ep_context_embed_mode = true; + bool ep_context_embed_mode = false; bool export_ep_ctx_blob = false; bool enable_qdq_optimizer = false; bool disable_cpu_fallback = false; + bool has_external_weights = false; size_t num_of_threads; std::string device_type; std::string precision_str; std::string model_precision; std::string cache_dir; + std::map load_config; std::string model_priority = "DEFAULT"; int num_streams; std::vector deviceAvailableList = {true, true, true, true, true, true, true, true}; diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index ee9486a62ea37..6d159db3b390d 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -21,7 +21,8 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, const bool& ep_context_embed_mode, std::string&& model_blob_str, const std::string& openvino_sdk_version) const { - auto model_build = graph_viewer.CreateModel(logger); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); + auto model_build = graph_viewer.CreateModel(logger, metadata); auto& graph_build = model_build->MainGraph(); // Get graph inputs and outputs @@ -94,17 +95,29 @@ Status 
EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, return Status::OK(); } -Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer) { +Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer, bool& ep_context_embed_mode) { auto node = graph_viewer.GetNode(0); auto& attrs = node->GetAttributes(); ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) > 0); - model_stream_ = std::make_shared(attrs.at(EP_CACHE_CONTEXT).s()); + + ep_cache_context_attribute_ = &attrs.at(EP_CACHE_CONTEXT); + + ep_context_embed_mode = static_cast(attrs.at(EMBED_MODE).i()); LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; is_valid_ep_ctx_graph_ = true; return Status::OK(); } +const std::string& EPCtxHandler::GetModelBlobStream() const { + static std::string empty; + if (ep_cache_context_attribute_ != nullptr) { + return ep_cache_context_attribute_->s(); + } else { + return empty; + } +} + bool EPCtxHandler::CheckForOVEPCtxNode(const GraphViewer& graph_viewer, std::string openvino_sdk_version) const { for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) { auto node = graph_viewer.GetNode(i); diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index c631d011d02b1..caab33b7db775 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -23,21 +23,21 @@ static const char SOURCE[] = "source"; class EPCtxHandler { public: EPCtxHandler() = default; - EPCtxHandler(const EPCtxHandler&) = default; + EPCtxHandler(const EPCtxHandler&) = delete; Status ExportEPCtxModel(const GraphViewer& graph_viewer, const std::string& graph_name, const logging::Logger& logger, const bool& ep_context_embed_mode, std::string&& model_blob_str, const std::string& openvino_sdk_version) const; - Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer); + Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer, bool& ep_context_embed_mode); bool CheckForOVEPCtxNode(const GraphViewer& graph_viewer, std::string openvino_sdk_version) const; bool IsValidOVEPCtxGraph() const { return is_valid_ep_ctx_graph_; } - [[nodiscard]] const std::shared_ptr GetModelBlobStream() const { return model_stream_; } + const std::string& GetModelBlobStream() const; private: bool is_valid_ep_ctx_graph_{false}; - std::shared_ptr model_stream_; + const onnx::AttributeProto* ep_cache_context_attribute_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 08144651319cf..72a188108adef 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -2,13 +2,16 @@ // Licensed under the MIT License #include #include - +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/openvino_execution_provider.h" #include "core/providers/openvino/contexts.h" #include "core/providers/openvino/backend_manager.h" #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/ov_versions/capability.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "openvino/core/version.hpp" #ifdef USE_OVEP_NPU_MEMORY #include "core/providers/openvino/ov_allocator.h" @@ -25,8 +28,8 @@ 
OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv global_context_ = std::make_unique(); global_context_->device_type = info.device_type_; global_context_->precision_str = info.precision_; - global_context_->enable_npu_fast_compile = info.enable_npu_fast_compile_; global_context_->cache_dir = info.cache_dir_; + global_context_->load_config = info.load_config_; global_context_->model_priority = info.model_priority_; global_context_->num_streams = info.num_streams_; global_context_->context = info.context_; @@ -124,6 +127,7 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, result = obj.Execute(); global_context_->is_wholly_supported_graph = obj.IsWhollySupportedGraph(); + global_context_->has_external_weights = obj.HasExternalWeights(); return result; } @@ -149,7 +153,7 @@ common::Status OpenVINOExecutionProvider::Compile( graph_body_viewer, *GetLogger(), ep_ctx_handle_); - + backend_manager_ = backend_manager; compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState(); @@ -185,16 +189,57 @@ common::Status OpenVINOExecutionProvider::Compile( #ifdef USE_OVEP_NPU_MEMORY std::vector OpenVINOExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo npu_allocator_info{ - [this](OrtDevice::DeviceId device_id) { - return std::make_unique(global_context_->ie_core.Get(), OrtDevice::NPU, device_id, OpenVINO_RT_NPU); - }, - 0, - }; - - // fill in allocator - return std::vector{CreateAllocator(npu_allocator_info)}; + if (global_context_->device_type.find("NPU") != std::string::npos) { + AllocatorCreationInfo npu_allocator_info{ + [this](OrtDevice::DeviceId device_id) { + return std::make_unique( + global_context_->ie_core.Get(), + OrtDevice::NPU, + device_id, + OpenVINO_RT_NPU); + }, + 0, + }; + + // fill in allocator + return std::vector{CreateAllocator(npu_allocator_info)}; + } else { + return std::vector{}; + } } #endif +common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + std::string workload_type = ""; + // Ensure the number of keys and values match + if (keys.size() != values.size()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Mismatched keys and values sizes."); + } + + for (size_t i = 0; i < keys.size(); ++i) { + std::string key = keys[i]; + std::string value = values[i]; + + if (key == kOrtEpDynamicOptionsWorkloadType) { + if (value == "Efficient") { + workload_type = "EFFICIENT"; + } else if (value == "Default") { + workload_type = "DEFAULT"; + } else { + LOGS_DEFAULT(WARNING) << "Unknown workload_type - ignoring " << key << "/" << value; + LOGS_DEFAULT(WARNING) << "Supported types are 'Efficient' and 'Default' \n"; + } + if (workload_type != "") { + LOGS_DEFAULT(INFO) << "SetEpDynamicOptions - modifying: " << key << "/" << value; + ov::CompiledModel& ov_compiled_model = backend_manager_->GetOVCompiledModel(); + ov_compiled_model.set_property(ov::workload_type(workload_type)); + } + } else { + // Handle unknown options + LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value; + } + } + return Status::OK(); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 8b1c62c607f6e..d5c22a4e2a9e4 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ 
b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -79,8 +79,8 @@ static std::vector parseDevices(const std::string& device_string, struct OpenVINOExecutionProviderInfo { std::string device_type_{""}; std::string precision_{""}; - bool enable_npu_fast_compile_{false}; size_t num_of_threads_{0}; + std::map load_config_{}; std::string cache_dir_{""}; std::string model_priority_{""}; int num_streams_{1}; @@ -90,20 +90,22 @@ struct OpenVINOExecutionProviderInfo { bool export_ep_ctx_blob_{false}; bool enable_qdq_optimizer_{false}; bool disable_cpu_fallback_{false}; - bool so_epctx_embed_mode_{true}; + bool so_epctx_embed_mode_{false}; OpenVINOExecutionProviderInfo() = delete; - explicit OpenVINOExecutionProviderInfo(const std::string& dev_type, const std::string& precision, - bool enable_npu_fast_compile, size_t num_of_threads, - const std::string& cache_dir, const std::string& model_priority, - int num_streams, void* context, bool enable_opencl_throttling, + explicit OpenVINOExecutionProviderInfo(std::string dev_type, const std::string& precision, + size_t num_of_threads, + const std::map& load_config, + const std::string& cache_dir, + const std::string& model_priority, int num_streams, + void* context, bool enable_opencl_throttling, bool disable_dynamic_shapes, bool export_ep_ctx_blob, bool enable_qdq_optimizer, bool disable_cpu_fallback, bool so_epctx_embed_mode) : precision_(std::move(precision)), - enable_npu_fast_compile_(enable_npu_fast_compile), num_of_threads_(num_of_threads), + load_config_(std::move(load_config)), cache_dir_(std::move(cache_dir)), model_priority_(std::move(model_priority)), num_streams_(num_streams), @@ -157,7 +159,7 @@ struct OpenVINOExecutionProviderInfo { device_type_ = std::move(dev_type); } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) { std::vector devices = parseDevices(dev_type, available_devices); - device_type_ = dev_type; + device_type_ = std::move(dev_type); } else { ORT_THROW("Invalid device string: " + dev_type); } @@ -186,6 +188,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider { Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; + Status SetEpDynamicOptions(gsl::span /*keys*/, + gsl::span /*values*/) override; + const void* GetExecutionHandle() const noexcept override { return nullptr; } @@ -194,6 +199,7 @@ class OpenVINOExecutionProvider : public IExecutionProvider { #endif private: std::unique_ptr global_context_; + std::shared_ptr backend_manager_; openvino_ep::EPCtxHandler ep_ctx_handle_{}; }; diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 077ecc717502f..5855cb594a08e 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -1,65 +1,82 @@ // Copyright (C) Intel Corporation // Licensed under the MIT License +#include +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/openvino_provider_factory.h" #include "core/providers/openvino/openvino_execution_provider.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "nlohmann/json.hpp" namespace onnxruntime { struct OpenVINOProviderFactory : IExecutionProviderFactory { - OpenVINOProviderFactory(const char* device_type, const char* precision, 
- bool enable_npu_fast_compile, size_t num_of_threads, - const char* cache_dir, const char* model_priority, - int num_streams, void* context, + OpenVINOProviderFactory(const std::string& device_type, const std::string& precision, + size_t num_of_threads, + const std::map& load_config, const std::string& cache_dir, + const std::string& model_priority, int num_streams, void* context, bool enable_opencl_throttling, bool disable_dynamic_shapes, - bool export_ep_ctx_blob, bool enable_qdq_optimizer, - bool disable_cpu_fallback, - bool so_epctx_embed_mode) - : precision_(precision), - enable_npu_fast_compile_(enable_npu_fast_compile), + bool enable_qdq_optimizer, const ConfigOptions& config_options) + : device_type_(device_type), + precision_(precision), num_of_threads_(num_of_threads), + load_config_(load_config), + cache_dir_(cache_dir), model_priority_(model_priority), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), disable_dynamic_shapes_(disable_dynamic_shapes), - export_ep_ctx_blob_(export_ep_ctx_blob), enable_qdq_optimizer_(enable_qdq_optimizer), - disable_cpu_fallback_(disable_cpu_fallback), - so_epctx_embed_mode_(so_epctx_embed_mode) { - device_type_ = (device_type == nullptr) ? "" : device_type; - cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir; - } + config_options_(config_options) {} - ~OpenVINOProviderFactory() override { - } + ~OpenVINOProviderFactory() override {} std::unique_ptr CreateProvider() override; private: std::string device_type_; std::string precision_; - bool enable_npu_fast_compile_; size_t num_of_threads_; + const std::map load_config_; std::string cache_dir_; std::string model_priority_; int num_streams_; void* context_; bool enable_opencl_throttling_; bool disable_dynamic_shapes_; - bool export_ep_ctx_blob_; bool enable_qdq_optimizer_; - bool disable_cpu_fallback_; - bool so_epctx_embed_mode_; + const ConfigOptions& config_options_; }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { - OpenVINOExecutionProviderInfo info(device_type_, precision_, enable_npu_fast_compile_, num_of_threads_, + bool so_disable_cpu_fallback = config_options_.GetConfigOrDefault(kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; + bool so_export_ep_ctx_blob = config_options_.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + bool so_epctx_embed_mode = config_options_.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; + std::string so_cache_path = config_options_.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "").c_str(); + + if (so_export_ep_ctx_blob && !so_cache_path.empty()) { + cache_dir_ = std::move(so_cache_path); + auto file_path = std::filesystem::path(cache_dir_); + // ep_context_file_path_ file extension must be .onnx + if (file_path.extension().generic_string() == ".onnx") { + // ep_context_file_path_ must be provided as a directory, create it if doesn't exist + auto parent_path = file_path.parent_path(); + if (!parent_path.empty() && !std::filesystem::is_directory(parent_path) && + !std::filesystem::create_directory(parent_path)) { + ORT_THROW("[ERROR] [OpenVINO] Failed to create directory : " + + file_path.parent_path().generic_string() + " \n"); + } + } else { + ORT_THROW("[ERROR] [OpenVINO] Invalid ep_ctx_file_path" + cache_dir_ + " \n"); + } + } + + OpenVINOExecutionProviderInfo info(device_type_, precision_, num_of_threads_, load_config_, cache_dir_, model_priority_, num_streams_, context_, enable_opencl_throttling_, - disable_dynamic_shapes_, export_ep_ctx_blob_, 
enable_qdq_optimizer_, - disable_cpu_fallback_, - so_epctx_embed_mode_); + disable_dynamic_shapes_, so_export_ep_ctx_blob, enable_qdq_optimizer_, + so_disable_cpu_fallback, so_epctx_embed_mode); return std::make_unique(info); } @@ -77,41 +94,42 @@ struct OpenVINO_Provider : Provider { void* GetInfo() override { return &g_info; } std::shared_ptr CreateExecutionProviderFactory(const void* void_params) override { - auto& provider_options_map = *reinterpret_cast(void_params); - - std::string device_type = ""; // [device_type]: Overrides the accelerator hardware type and precision - // with these values at runtime. - std::string precision = ""; // [precision]: Sets the inference precision for execution. - // Supported precision for devices are CPU=FP32, GPU=FP32,FP16, NPU=FP16. - // Not setting precision will execute with optimized precision for - // best inference latency. set Precision=ACCURACY for executing models - // with input precision for best accuracy. - bool enable_npu_fast_compile = false; // [enable_npu_fast_compile]: Fast-compile may be optionally enabled to - // speeds up the model's compilation to NPU device specific format. - int num_of_threads = 0; // [num_of_threads]: Overrides the accelerator default value of number of - // threads with this value at runtime. - std::string cache_dir = ""; // [cache_dir]: specify the path to - // dump and load the blobs for the model caching/kernel caching (GPU) - // feature. If blob files are already present, it will be directly loaded. - const char* model_priority = "DEFAULT"; // High-level OpenVINO model priority hint - // Defines what model should be provided with more performant - // bounded resource first - int num_streams = 1; // [num_streams]: Option that specifies the number of parallel inference - // requests to be processed on a given `device_type`. Overrides the - // accelerator default value of number of streams - // with this value at runtime. - bool enable_opencl_throttling = false; // [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU - // device (Reduces CPU Utilization when using GPU) - bool export_ep_ctx_blob = false; // Whether to export the pre-compiled blob as an EPContext model. + // Extract the void_params into ProviderOptions and ConfigOptions + typedef std::pair ConfigBuffer; + const ConfigBuffer* buffer = reinterpret_cast(void_params); + auto& provider_options_map = *buffer->first; + const ConfigOptions& config_options = buffer->second; + + std::string device_type = ""; // [device_type]: Overrides the accelerator hardware type and + // precision with these values at runtime. + std::string precision = ""; // [precision]: Sets the inference precision for execution. + // Supported precision for devices are + // CPU=FP32, GPU=FP32,FP16, NPU=FP16. + // Not setting precision will execute with optimized precision for + // best inference latency. set Precision=ACCURACY for executing + // models with input precision for best accuracy. + int num_of_threads = 0; // [num_of_threads]: Overrides the accelerator default value of + // number of threads with this value at runtime. + std::map load_config; // JSON config map to load custom OV parameters. + std::string cache_dir = ""; // [cache_dir]: specify the path to + // dump and load the blobs for the model caching/kernel caching + // (GPU) feature. If blob files are already present, + // it will be directly loaded. 
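// Illustrative usage sketch (not part of this patch): with the CreateProvider() change above, the
// EP-context export, embed mode, output path, and CPU-fallback switches are read from ORT session
// configuration entries rather than OpenVINO provider options. A minimal sketch, assuming the C++
// API wrapper and the AppendExecutionProvider_OpenVINO_V2 overload available in recent ORT
// releases; the option values shown are examples only.
#include <string>
#include <unordered_map>
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

static Ort::SessionOptions MakeOpenVINOSessionOptions() {
  Ort::SessionOptions so;
  // Session-level switches now consumed by OpenVINOProviderFactory::CreateProvider().
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");      // export a pre-compiled EPContext model
  so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0");   // keep the blob in a separate file
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "ov_ctx/model_ctx.onnx");  // must end in .onnx
  so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");
  // The remaining knobs still travel as provider options.
  std::unordered_map<std::string, std::string> ov_options{
      {"device_type", "NPU"},
      {"precision", "FP16"},
      {"num_of_threads", "4"}};
  so.AppendExecutionProvider_OpenVINO_V2(ov_options);
  return so;
}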
+ std::string model_priority = "DEFAULT"; // High-level OpenVINO model priority hint + // Defines what model should be provided with more performant + // bounded resource first + int num_streams = 1; // [num_streams]: Option that specifies the number of parallel + // inference requests to be processed on a given `device_type`. + // Overrides the accelerator default value of number of streams + // with this value at runtime. + bool enable_opencl_throttling = false; // [enable_opencl_throttling]: Enables OpenCL queue throttling for + // GPU device (Reduces CPU Utilization when using GPU) + + bool enable_qdq_optimizer = false; // Enables QDQ pruning for efficient inference latency with NPU void* context = nullptr; - bool enable_qdq_optimizer = false; - - bool disable_cpu_fallback = false; - - bool so_epctx_embed_mode = true; - + std::string bool_flag = ""; if (provider_options_map.find("device_type") != provider_options_map.end()) { device_type = provider_options_map.at("device_type").c_str(); @@ -185,6 +203,68 @@ struct OpenVINO_Provider : Provider { cache_dir = provider_options_map.at("cache_dir"); } + if (provider_options_map.find("load_config") != provider_options_map.end()) { + auto parse_config = [&](const std::string& config_str) -> std::map { + // If the config string is empty, return an empty map and skip processing + if (config_str.empty()) { + LOGS_DEFAULT(WARNING) << "Empty OV Config Map passed. Skipping load_config option parsing.\n"; + return {}; + } + + std::stringstream input_str_stream(config_str); + std::map target_map; + + try { + nlohmann::json json_config = nlohmann::json::parse(input_str_stream); + + if (!json_config.is_object()) { + ORT_THROW("Invalid JSON structure: Expected an object at the root."); + } + + for (auto& [key, value] : json_config.items()) { + ov::AnyMap inner_map; + + // Ensure the key is one of "CPU", "GPU", or "NPU" + if (key != "CPU" && key != "GPU" && key != "NPU") { + LOGS_DEFAULT(WARNING) << "Unsupported device key: " << key << ". Skipping entry.\n"; + continue; + } + + // Ensure that the value for each device is an object (PROPERTY -> VALUE) + if (!value.is_object()) { + ORT_THROW("Invalid JSON structure: Expected an object for device properties."); + } + + for (auto& [inner_key, inner_value] : value.items()) { + if (inner_value.is_string()) { + inner_map[inner_key] = inner_value.get(); + } else if (inner_value.is_number_integer()) { + inner_map[inner_key] = inner_value.get(); + } else if (inner_value.is_number_float()) { + inner_map[inner_key] = inner_value.get(); + } else if (inner_value.is_boolean()) { + inner_map[inner_key] = inner_value.get(); + } else { + LOGS_DEFAULT(WARNING) << "Unsupported JSON value type for key: " << inner_key << ". 
Skipping key."; + } + } + target_map[key] = std::move(inner_map); + } + } catch (const nlohmann::json::parse_error& e) { + // Handle syntax errors in JSON + ORT_THROW("JSON parsing error: " + std::string(e.what())); + } catch (const nlohmann::json::type_error& e) { + // Handle invalid type accesses + ORT_THROW("JSON type error: " + std::string(e.what())); + } catch (const std::exception& e) { + ORT_THROW("Error parsing load_config Map: " + std::string(e.what())); + } + return target_map; + }; + + load_config = parse_config(provider_options_map.at("load_config")); + } + if (provider_options_map.find("context") != provider_options_map.end()) { std::string str = provider_options_map.at("context"); uint64_t number = std::strtoull(str.c_str(), nullptr, 16); @@ -224,16 +304,6 @@ struct OpenVINO_Provider : Provider { << "Executing with num_streams=1"; } } - std::string bool_flag = ""; - if (provider_options_map.find("enable_npu_fast_compile") != provider_options_map.end()) { - bool_flag = provider_options_map.at("enable_npu_fast_compile"); - if (bool_flag == "true" || bool_flag == "True") - enable_npu_fast_compile = true; - else if (bool_flag == "false" || bool_flag == "False") - enable_npu_fast_compile = false; - bool_flag = ""; - } - if (provider_options_map.find("enable_opencl_throttling") != provider_options_map.end()) { bool_flag = provider_options_map.at("enable_opencl_throttling"); if (bool_flag == "true" || bool_flag == "True") @@ -249,6 +319,8 @@ struct OpenVINO_Provider : Provider { enable_qdq_optimizer = true; else if (bool_flag == "false" || bool_flag == "False") enable_qdq_optimizer = false; + else + ORT_THROW("[ERROR] [OpenVINO-EP] enable_qdq_optimiser should be a boolean.\n"); bool_flag = ""; } @@ -271,68 +343,21 @@ struct OpenVINO_Provider : Provider { disable_dynamic_shapes = false; } } - } - if (provider_options_map.find("so_export_ep_ctx_blob") != provider_options_map.end()) { - bool_flag = provider_options_map.at("so_export_ep_ctx_blob"); - if (bool_flag == "true" || bool_flag == "True") - export_ep_ctx_blob = true; - else if (bool_flag == "false" || bool_flag == "False") - export_ep_ctx_blob = false; - bool_flag = ""; - } - - if (provider_options_map.find("disable_cpu_fallback") != provider_options_map.end()) { - bool_flag = provider_options_map.at("disable_cpu_fallback"); - if (bool_flag == "true" || bool_flag == "True") - disable_cpu_fallback = true; - else if (bool_flag == "false" || bool_flag == "False") - disable_cpu_fallback = false; - bool_flag = ""; - } - if (provider_options_map.find("so_epctx_embed_mode") != provider_options_map.end()) { - bool_flag = provider_options_map.at("so_epctx_embed_mode"); - if (bool_flag == "true" || bool_flag == "True") - so_epctx_embed_mode = true; - else if (bool_flag == "false" || bool_flag == "False") - so_epctx_embed_mode = false; bool_flag = ""; } - if (provider_options_map.find("so_epctx_path") != provider_options_map.end()) { - // The path to dump epctx model is valid only when epctx is enabled. - // Overrides the cache_dir option to dump model cache files from OV. 
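// Illustrative sketch (not part of this patch) of the "load_config" format accepted by the parser
// above: a JSON object keyed by "CPU", "GPU", or "NPU", each mapping OpenVINO property names to
// string/int/float/bool values. The real code converts each entry into a per-device ov::AnyMap and
// skips unknown device keys; this standalone sketch only echoes the parsed entries, and the
// property names shown are examples rather than a list of supported OpenVINO options.
#include <iostream>
#include <string>
#include "nlohmann/json.hpp"

int main() {
  const std::string load_config =
      R"({"NPU": {"PERFORMANCE_HINT": "LATENCY"},
          "GPU": {"INFERENCE_PRECISION_HINT": "f16", "NUM_STREAMS": 2}})";
  nlohmann::json cfg = nlohmann::json::parse(load_config);
  for (const auto& [device, props] : cfg.items()) {    // "CPU" | "GPU" | "NPU"
    for (const auto& [name, value] : props.items()) {  // PROPERTY -> VALUE
      std::cout << device << ": " << name << " = " << value.dump() << "\n";
    }
  }
  return 0;
}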
- if (export_ep_ctx_blob && - !provider_options_map.at("so_epctx_path").empty()) { - cache_dir = provider_options_map.at("so_epctx_path"); - auto file_path = std::filesystem::path(cache_dir); - // ep_context_file_path_ file extension must be .onnx - if (file_path.extension().generic_string() == ".onnx") { - // ep_context_file_path_ must be provided as a directory, create it if doesn't exist - auto parent_path = file_path.parent_path(); - if (!parent_path.empty() && !std::filesystem::is_directory(parent_path) && - !std::filesystem::create_directory(parent_path)) { - ORT_THROW("[ERROR] [OpenVINO] Failed to create directory : " + file_path.parent_path().generic_string() + " \n"); - } - } else { - ORT_THROW("[ERROR] [OpenVINO] Invalid ep_ctx_file_path" + cache_dir + " \n"); - } - } - } - - return std::make_shared(const_cast(device_type.c_str()), - const_cast(precision.c_str()), - enable_npu_fast_compile, + return std::make_shared(device_type, + precision, num_of_threads, - const_cast(cache_dir.c_str()), + load_config, + cache_dir, model_priority, num_streams, context, enable_opencl_throttling, disable_dynamic_shapes, - export_ep_ctx_blob, enable_qdq_optimizer, - disable_cpu_fallback, - so_epctx_embed_mode); + config_options); } void Initialize() override { diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h b/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h index bff70a90b6a70..0cbf051c6df26 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h @@ -14,8 +14,7 @@ namespace onnxruntime { struct SessionOptions; // defined in provider_bridge_ort.cc struct OpenVINOProviderFactoryCreator { - static std::shared_ptr Create(ProviderOptions* provider_options_map, + static std::shared_ptr Create(const ProviderOptions* provider_options_map, const SessionOptions* session_options); - static std::shared_ptr Create(const OrtOpenVINOProviderOptions* provider_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc index 6700244b754d8..0e5ff8ff98efb 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.cc +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -39,7 +39,6 @@ void* OVRTAllocator::Alloc(size_t size) { } catch (const ov::Exception& e) { ORT_THROW(std::string("Alloc failed: ") + e.what()); } - return nullptr; } void OVRTAllocator::Free(void* p) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 7e8681d304abf..12ab7ecede031 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -109,7 +109,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, } } -OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream, +OVExeNetwork OVCore::ImportModel(const std::string& model_string, std::string hw_target, const ov::AnyMap& device_config, bool embed_mode, @@ -117,10 +117,10 @@ OVExeNetwork OVCore::ImportModel(std::shared_ptr model_strea try { ov::CompiledModel obj; if (embed_mode) { - obj = oe.import_model(*model_stream, hw_target, device_config); + std::istringstream model_stream(model_string); + obj = oe.import_model(model_stream, hw_target, device_config); } else { - std::string blob_file_path = (*model_stream).str(); - std::ifstream modelStream(blob_file_path, 
std::ios_base::binary | std::ios_base::in); + std::ifstream modelStream(model_string, std::ios_base::binary | std::ios_base::in); obj = oe.import_model(modelStream, hw_target, {}); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index f4da4ea3e3244..c3417003f8e1f 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -54,7 +54,7 @@ class OVCore { ov::AnyMap& device_config, const std::string& name); // OV Interface for Import model Stream - OVExeNetwork ImportModel(std::shared_ptr model_stream, + OVExeNetwork ImportModel(const std::string& model_string, std::string hw_target, const ov::AnyMap& device_config, bool embed_mode, diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 3fcaff4369c89..3e780f74145ae 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -35,16 +35,14 @@ GetCapability::GetCapability(const GraphViewer& graph_viewer_param, device_type_ = "CPU"; if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true; } -#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0 - data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 1 - data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 2 - data_ops_ = new DataOps(graph_viewer_, V_2024_2, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 3 - data_ops_ = new DataOps(graph_viewer_, V_2024_3, device_type_, npu_qdq_optimizer_enabled); +#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 4 + data_ops_ = new DataOps(graph_viewer_, V_2024_4, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5 + data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 + data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); #else - data_ops_ = new DataOps(graph_viewer_, V_2024_3, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); #endif } @@ -59,7 +57,7 @@ std::vector> GetCapability::Execute() { // This is a list of initializers that nGraph considers as constants. Example weights, reshape shape etc. 
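      // (In addition to collecting these required initializers, GetUnsupportedNodeIndices() below now
      // also reports, via has_external_weights_, whether any initializer is stored with an external
      // data location; callers can query that through GetCapability::HasExternalWeights().)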
std::unordered_set ng_required_initializers; - const auto unsupported_nodes = data_ops_->GetUnsupportedNodeIndices(ng_required_initializers); + const auto unsupported_nodes = data_ops_->GetUnsupportedNodeIndices(ng_required_initializers, has_external_weights_); #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { std::cout << "No of unsupported nodes " << unsupported_nodes.size() << std::endl; diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.h b/onnxruntime/core/providers/openvino/ov_versions/capability.h index 63c83158accf8..2f87c4c73d892 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.h +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.h @@ -16,6 +16,7 @@ class GetCapability { std::string device_type_; DataOps* data_ops_; bool is_wholly_supported_graph_ = false; + bool has_external_weights_ = false; public: GetCapability(const GraphViewer& graph_viewer_param, @@ -25,6 +26,9 @@ class GetCapability { bool IsWhollySupportedGraph() { return is_wholly_supported_graph_; } + bool HasExternalWeights() { + return has_external_weights_; + } }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index d9aa13ec1bba9..f118f057ac11e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -118,6 +118,7 @@ std::vector supported_op_mode = { {"CumSum", V_2022_1, {"CPU", "GPU"}}, {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, + {"DequantizeLinear", V_2024_4, {"NPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, {"Elu", V_2020_4, {"CPU", "GPU"}}, @@ -254,6 +255,8 @@ void DataOps::populate_types_supported() { std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); supported_types_initializer_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); supported_types_initializer_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)); supported_types_initializer_.insert( @@ -262,6 +265,10 @@ void DataOps::populate_types_supported() { std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); supported_types_initializer_.insert( std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_initializer_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_initializer_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); supported_types_npu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); @@ -281,6 +288,14 @@ void DataOps::populate_types_supported() { std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); supported_types_npu_.insert( std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_npu_.insert( + std::make_pair(V_2024_3, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN)); + supported_types_npu_.insert( + std::make_pair(V_2024_3, 
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ)); + supported_types_npu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_npu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); supported_types_cpu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); @@ -300,6 +315,10 @@ void DataOps::populate_types_supported() { std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); supported_types_cpu_.insert( std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_cpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_cpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); supported_types_gpu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); @@ -315,6 +334,10 @@ void DataOps::populate_types_supported() { std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); supported_types_gpu_.insert( std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_gpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_gpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); } void DataOps::populate_op_mode_supported() { @@ -328,9 +351,11 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Expand", V_2023_3, {"CPU"}}); + no_dimension_supported_.push_back({"Expand", V_2024_3, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Identity", V_2023_0, {"All"}}); + no_dimension_supported_.push_back({"If", V_2022_3, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); @@ -363,7 +388,7 @@ void DataOps::populate_op_mode_supported() { // populate unsupportedmode_t { - UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3}, + UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { // If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { @@ -378,7 +403,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"ReduceMax", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, + V_2024_3, V_2024_4, V_2024_5, V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); @@ -395,7 +421,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Reshape", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, 
V_2024_0, V_2024_1, V_2024_2, V_2024_3}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, + V_2024_3, V_2024_4, V_2024_5, V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze // If axes is an input, then we cannot produce a static graph. @@ -410,7 +437,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, + V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); @@ -583,11 +611,21 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { } } -bool DataOps::unsupported_op_mode(const Node* node) { +bool DataOps::unsupported_op_mode(const Node* node, bool& has_external_weights_) { bool result = false; const auto& optype = node->OpType(); const auto& initializers = graph_viewer_.GetAllInitializedTensors(); + for (const auto& tensor_pair : initializers) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = tensor_pair.second; + // Check if the tensor exists and if it has an external data location + if (tensor_proto && tensor_proto->has_data_location() && + tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + has_external_weights_ = true; + break; + } + } + auto iter = op_list_.equal_range(optype); for (auto it = iter.first; it != iter.second; ++it) { auto ob = it->second; @@ -637,7 +675,7 @@ bool DataOps::dimension_unsupported(const Node* node) { return true; } -bool DataOps::node_is_supported(const NodeIndex node_idx) { +bool DataOps::node_is_supported(const NodeIndex node_idx, bool& has_external_weights_) { const auto& node = graph_viewer_.GetNode(node_idx); const auto& optype = node->OpType(); @@ -745,7 +783,7 @@ bool DataOps::node_is_supported(const NodeIndex node_idx) { } // Check 3a - if (domain == kOnnxDomain && unsupported_op_mode(node)) { + if (domain == kOnnxDomain && unsupported_op_mode(node, has_external_weights_)) { if (optype == "GatherElements") { return true; } @@ -760,11 +798,12 @@ bool DataOps::node_is_supported(const NodeIndex node_idx) { return true; } -std::vector DataOps::GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers) { +std::vector DataOps::GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers, + bool& has_external_weights_) { std::vector unsupported_nodes_idx; for (const auto& node_idx : graph_viewer_.GetNodesInTopologicalOrder()) { - if (node_is_supported(node_idx)) { + if (node_is_supported(node_idx, has_external_weights_)) { // Collect inputs that are initializers graph_viewer_.GetNode(node_idx)->ForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, bool is_input) { diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index 4c064b08405c1..07fa36f355d55 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -30,7 +30,10 @@ enum versionNum { V_2024_0, V_2024_1, V_2024_2, - V_2024_3 + V_2024_3, + V_2024_4, + V_2024_5, + V_2025_0 }; using VersionNum = enum versionNum; @@ -70,9 +73,9 @@ class DataOps { void populate_types_supported(); bool op_is_supported(std::string name, 
std::vector& list); bool dimension_unsupported(const Node* node); - bool unsupported_op_mode(const Node* node); + bool unsupported_op_mode(const Node* node, bool& has_external_weights_); bool type_is_supported(const NodeArg* node_arg, bool is_initializer); - bool node_is_supported(const NodeIndex node_idx); + bool node_is_supported(const NodeIndex node_idx, bool& has_external_weights_); public: DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, @@ -85,7 +88,8 @@ class DataOps { populate_types_supported(); } - virtual std::vector GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers); + virtual std::vector GetUnsupportedNodeIndices( + std::unordered_set& ng_required_initializers, bool& has_external_weights_); virtual bool IsOpSupportedOnlyInModel(std::string name); virtual bool SpecialConditionForClusterSizeOne( std::unordered_set& ng_required_initializers, const Node* node); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index f1df1abf4c49a..387aaf9985b4c 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -30,6 +30,10 @@ constexpr std::string_view DuplicateDQ = "/duplicated"; constexpr ONNX_NAMESPACE::TensorProto_DataType DT_UINT16 = ONNX_NAMESPACE::TensorProto_DataType_UINT16; constexpr ONNX_NAMESPACE::TensorProto_DataType DT_INT16 = ONNX_NAMESPACE::TensorProto_DataType_INT16; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_UINT8 = ONNX_NAMESPACE::TensorProto_DataType_UINT8; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_INT8 = ONNX_NAMESPACE::TensorProto_DataType_INT8; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_UINT4 = ONNX_NAMESPACE::TensorProto_DataType_UINT4; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_INT4 = ONNX_NAMESPACE::TensorProto_DataType_INT4; // Return the data type of the qdq node. 
// Check output type of Q and input type of DQ to determine it as zero_point is an optional input and may not exist @@ -218,7 +222,7 @@ static bool DQFeedsASupportedOp(const Node* dq_node) { } else { return true; } - } else if (op_type == "Add") { + } else if (op_type == "Add" && !(GetQDQDataType(dq_node) == DT_UINT16 || GetQDQDataType(dq_node) == DT_INT16)) { // Add => keeps all DQs return true; } @@ -687,7 +691,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger); std::unordered_set seen_node_units; const auto& node_indices = src_graph.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 41e418d9eb97f..1c62c1a7a8d0b 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -78,10 +78,6 @@ #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" #endif -#if defined(USE_TVM) -#include "core/providers/tvm/tvm_provider_factory_creator.h" -#endif - #if defined(USE_VITISAI) #include "core/providers/vitisai/vitisai_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 57ae8c354abb7..79674fd706151 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector& names, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - QnnModelLookupTable& qnn_models) { + QnnModelLookupTable& qnn_models, + int64_t max_spill_fill_size) { ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node."); NodeAttrHelper node_helper(main_context_node); bool is_embed_mode = node_helper.Get(EMBED_MODE, true); @@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), main_context_node.Name(), - qnn_models); + qnn_models, + max_spill_fill_size); } std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); @@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), main_context_node.Name(), - qnn_models); + qnn_models, + max_spill_fill_size); +} + +Status TryGetMaxSpillFillSize(const std::vector& fused_nodes_and_graphs, + uint32_t total_context_size, + int64_t& max_spill_fill_size, + std::vector& main_context_pos_list) { + max_spill_fill_size = 0; + int max_size_index = 0; + for (uint32_t i = 0; i < total_context_size; ++i) { + auto index = main_context_pos_list[i]; + const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph); + ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext 
node!"); + const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin(); + NodeAttrHelper node_helper(*ep_context_node); + int64_t max_size = node_helper.Get(MAX_SIZE, static_cast(0)); + if (max_size > max_spill_fill_size) { + max_spill_fill_size = max_size; + max_size_index = i; + } + } + if (0 != max_size_index) { + int tmp_index = main_context_pos_list[0]; + main_context_pos_list[0] = main_context_pos_list[max_size_index]; + main_context_pos_list[max_size_index] = tmp_index; + } + + return Status::OK(); } Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, QnnModelLookupTable& qnn_models, - const logging::Logger& logger) { + const logging::Logger& logger, + int64_t max_spill_fill_size) { ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!"); Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, - qnn_models); + qnn_models, max_spill_fill_size); // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { @@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model, const QnnModelLookupTable& qnn_models, const onnxruntime::PathString& context_cache_path, bool qnn_context_embed_mode, + uint64_t max_spill_fill_buffer_size, const logging::Logger& logger) { auto& graph = model->MainGraph(); @@ -238,6 +270,7 @@ Status CreateEPContextNodes(Model* model, } of_stream.write(reinterpret_cast(buffer), buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name); + ep_node.AddAttribute(MAX_SIZE, static_cast(max_spill_fill_buffer_size)); } } else { ep_node.AddAttribute(MAIN_CONTEXT, static_cast(0)); diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index f308a7456d46c..92c5391b40f09 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string EP_SDK_VER = "ep_sdk_version"; static const std::string PARTITION_NAME = "partition_name"; static const std::string SOURCE = "source"; +static const std::string MAX_SIZE = "max_size"; bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer); @@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - QnnModelLookupTable& qnn_models); + QnnModelLookupTable& qnn_models, + int64_t max_spill_fill_size); + +Status TryGetMaxSpillFillSize(const std::vector& fused_nodes_and_graphs, + uint32_t total_context_size, + int64_t& max_spill_fill_size, + std::vector& main_context_pos_list); Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, QnnModelLookupTable& qnn_models, - const logging::Logger& logger); + const logging::Logger& logger, + int64_t max_spill_fill_size); Status CreateEPContextNodes(Model* model, unsigned char* buffer, @@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model, const std::unordered_map>& qnn_models, const onnxruntime::PathString& 
context_cache_path, bool qnn_context_embed_mode, + uint64_t max_spill_fill_buffer_size, const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index dd5c6a5a79cdb..6ef17b40d274b 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -83,6 +83,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateReduceOpBuilder("ReduceMin", *this); CreateReduceOpBuilder("ReduceProd", *this); CreateReduceOpBuilder("ReduceSum", *this); + CreateReduceOpBuilder("ReduceL2", *this); } { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index 5c4608dff9bb1..d1a0e88686f39 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -87,10 +87,10 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); } -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR == 17 || QNN_API_VERSION_MINOR == 18 || QNN_API_VERSION_MINOR == 19) +#if QNN_API_VERSION_MAJOR == 2 && QNN_API_VERSION_MINOR >= 17 && QNN_API_VERSION_MINOR <= 20 if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { - // Bias is implicit. QNN SDK 2.24/2.25/2.26 (QNN API version 2.17/2.18/2.19) has a validation bug for implicit bias inputs, - // so provide an explicit bias of all 0 (quantized int32). + // Bias is implicit. QNN SDK 2.24 to 2.27 (QNN API version 2.17 to 2.20) has a validation bug for + // implicit bias inputs, so provide an explicit bias of all 0 (quantized int32). 
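  // (Version mapping implied by the comment above: QNN API 2.17/2.18/2.19/2.20 ship with
  //  QNN SDK 2.24/2.25/2.26/2.27 respectively, so the widened guard now also covers SDK 2.27.)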
TensorInfo x_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc index 2aefe5f6b8e71..77bc58bd6f833 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc @@ -6,15 +6,15 @@ #include #include +#include "core/common/safeint.h" +#include "onnx/defs/data_type_utils.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" #include "core/framework/endian_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" -#include "onnx/defs/data_type_utils.h" - -#include "base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { namespace qnn { @@ -25,6 +25,7 @@ enum ReduceOpType { REDUCE_OP_TYPE_MEAN, REDUCE_OP_TYPE_PROD, REDUCE_OP_TYPE_SUM, + REDUCE_OP_TYPE_L2, REDUCE_OP_TYPE_COUNT, REDUCE_OP_TYPE_UNKNOWN, @@ -41,6 +42,8 @@ ReduceOpType GetReduceOpType(const std::string& op_type) { return REDUCE_OP_TYPE_PROD; } else if (op_type == "ReduceSum") { return REDUCE_OP_TYPE_SUM; + } else if (op_type == "ReduceL2") { + return REDUCE_OP_TYPE_L2; } else { return REDUCE_OP_TYPE_UNKNOWN; } @@ -51,21 +54,16 @@ class ReduceOpBuilder : public BaseOpBuilder { ReduceOpBuilder() : BaseOpBuilder("ReduceOpBuilder") {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ReduceOpBuilder); - Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; protected: - Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger, + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, std::vector& input_names, bool do_op_validation = false) const override ORT_MUST_USE_RESULT; - Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const override ORT_MUST_USE_RESULT; private: @@ -84,7 +82,8 @@ const std::array ReduceOpBuilder::opset_with_axes_as_ 18, // ReduceMin 18, // ReduceMean 18, // ReduceProd - 13 // ReduceSum + 13, // ReduceSum + 18, // ReduceL2 }; Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -175,8 +174,7 @@ Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const Nod return Status::OK(); } -Status ReduceOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, +Status ReduceOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { ReduceOpType reduce_op_type = GetReduceOpType(node_unit.OpType()); if (reduce_op_type == 
ReduceOpType::REDUCE_OP_TYPE_UNKNOWN) { @@ -188,13 +186,17 @@ Status ReduceOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: ReduceProd operator not supported by HTP backend."); } + // ReduceL2 is composed by Mul->ReduceSum->Sqrt, it's not easy to set the quantization parameters for the activation + // tensors between, so we don't support ReduceL2 with quantized input for now. + if (reduce_op_type == ReduceOpType::REDUCE_OP_TYPE_L2 && node_unit.Inputs()[0].quant_param.has_value()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: ReduceL2 operator does not support quantized input for now."); + } + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } -Status ReduceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger, - std::vector& input_names, +Status ReduceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + const logging::Logger& logger, std::vector& input_names, bool do_op_validation) const { ORT_UNUSED_PARAMETER(do_op_validation); @@ -207,11 +209,9 @@ Status ReduceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, +Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { + const logging::Logger& logger, bool do_op_validation) const { NodeAttrHelper node_attr_helper(node_unit); std::vector param_tensor_names; @@ -229,8 +229,8 @@ Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::transform(axes_set.begin(), axes_set.end(), axes_data.begin(), [](AxesOnnxIntType item) { return SafeInt(item); }); - QnnParamWrapper axes_param(node_unit.Index(), node_unit.Name(), QNN_OP_REDUCE_MAX_PARAM_AXES, - std::move(axes_shape), std::move(axes_data)); + QnnParamWrapper axes_param(node_unit.Index(), node_unit.Name(), QNN_OP_REDUCE_MAX_PARAM_AXES, std::move(axes_shape), + std::move(axes_data)); param_tensor_names.push_back(axes_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(axes_param)); @@ -245,10 +245,57 @@ Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w param_tensor_names.push_back(keep_dims_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(keep_dims_param)); - ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, - std::move(input_names), - std::move(param_tensor_names), - logger, do_op_validation, GetQnnOpType(node_unit.OpType()))); + if (node_unit.OpType() == "ReduceL2") { + // If ReduceL2, QNN doesn't have a single Op for it, we need to add a + // ElementWiseMultiply->ReduceSum->ElementWiseSquareRoot node sequence. 
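      // Worked example of the decomposition (illustrative): for input x = [3, 4] reduced over its only
      // axis, ElementWiseMultiply yields x*x = [9, 16], ReduceSum yields 25, and ElementWiseSquareRoot
      // yields 5, which equals ||x||_2 as ReduceL2 requires.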
+ const auto& input = node_unit.Inputs()[0]; + const auto& output = node_unit.Outputs()[0]; + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, input_shape), "Cannot get input shape."); + std::vector output_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape), "Cannot get output shape."); + ORT_ENFORCE(!input.quant_param.has_value(), "Input tensor must not be quantized."); + const auto* type_proto = output.node_arg.TypeAsProto(); + Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type)); + const std::string input_name = input_names[0]; + + // Step 1: y_pow2 = x * x, using ElementWiseMultiply instead of ElementWisePower so we don't need to add a new + // initializer tensor for the power value. The performance difference is negligible. + const std::string pow2_name = input_name + "_ort_qnn_ep_pow2"; + QnnTensorWrapper pow2_tensorwrapper(pow2_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), + std::move(input_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(pow2_tensorwrapper)), "AddTensorWrapper failed"); + ORT_RETURN_IF_NOT( + qnn_model_wrapper.CreateQnnNode(pow2_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_MULTIPLY, + {input_name, input_name}, {pow2_name}, {}, do_op_validation), + "CreateQnnNode failed"); + + // Step 2: y_pow2_sum = ReduceSum(y_pow2) + const std::string reduce_name = input_name + "_ort_qnn_ep_pow2_sum"; + QnnTensorWrapper reduce_tensorwrapper(reduce_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), + std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reduce_tensorwrapper)), "AddTensorWrapper failed"); + ORT_RETURN_IF_NOT( + qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_REDUCE_SUM, + {pow2_name}, {reduce_name}, std::move(param_tensor_names), do_op_validation), + "CreateQnnNode failed"); + + // Step 3: y = Sqrt(y_pow2_sum) + Qnn_TensorType_t output_tensor_type = + qnn_model_wrapper.IsGraphOutput(output.node_arg.Name()) ? 
QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper sqrt_tensorwrapper(output.node_arg.Name(), output_tensor_type, qnn_data_type, + QnnQuantParamsWrapper(), std::move(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(sqrt_tensorwrapper)), "AddTensorWrapper failed"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(input_name + "_ort_qnn_ep_pow2_sum_sqrt", + QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_SQUARE_ROOT, + {reduce_name}, {output.node_arg.Name()}, {}, do_op_validation), + "CreateQnnNode failed"); + } else { + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), + std::move(param_tensor_names), logger, do_op_validation, + GetQnnOpType(node_unit.OpType()))); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 0358fae3c2115..a6c4203ad92e4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -164,6 +164,11 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, int64_t quant_axis = 0; ORT_RETURN_IF_ERROR(qnn_model_wrapper.IsPerChannelQuantized(node_unit.Inputs()[0], is_per_chan_quant, quant_axis)); ORT_RETURN_IF(is_per_chan_quant, "QNN EP does not support a standalone DQ op with per-channel quantization"); + + if (qnn_model_wrapper.GetModelSettings().offload_graph_io_quantization) { + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(node_unit.Outputs()[0].node_arg.Name()), + "QNN EP is configured to not take DQ nodes that generate a graph output."); + } } if (op_type == "QuantizeLinear") { @@ -171,6 +176,11 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, int64_t quant_axis = 0; ORT_RETURN_IF_ERROR(qnn_model_wrapper.IsPerChannelQuantized(node_unit.Outputs()[0], is_per_chan_quant, quant_axis)); ORT_RETURN_IF(is_per_chan_quant, "QNN EP does not support a standalone Q op with per-channel quantization"); + + if (qnn_model_wrapper.GetModelSettings().offload_graph_io_quantization) { + ORT_RETURN_IF(qnn_model_wrapper.IsGraphInput(node_unit.Inputs()[0].node_arg.Name()), + "QNN EP is configured to not take Q nodes that consume a graph input."); + } } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index eaffe1e2ac224..3af646c3ce13a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -8,12 +8,14 @@ #include #include "QnnOpDef.h" #include "HTP/QnnHtpPerfInfrastructure.h" +#include "HTP/QnnHtpSystemContext.h" #include "CPU/QnnCpuCommon.h" // TODO: not exist for Windows yet // #include "GPU/QnnGpuCommon.h" #include "DSP/QnnDspCommon.h" #include "HTP/QnnHtpCommon.h" #include "HTP/QnnHtpContext.h" +#include "Saver/QnnSaver.h" #include #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" @@ -302,13 +304,21 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity } Status QnnBackendManager::ResetQnnLogLevel() { - auto ort_log_level = logger_->GetSeverity(); - LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; - return UpdateQnnLogLevel(ort_log_level); + std::lock_guard lock(logger_mutex_); + + if 
(backend_setup_completed_ && logger_ != nullptr) { + auto ort_log_level = logger_->GetSeverity(); + LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; + return UpdateQnnLogLevel(ort_log_level); + } + return Status::OK(); } Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. Invalid QNN log handle."); + ORT_RETURN_IF(false == backend_setup_completed_, "Unable to update QNN Log Level. Backend setup not completed."); + ORT_RETURN_IF(nullptr == logger_, "Unable to update QNN Log Level. Invalid logger."); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level; @@ -523,11 +533,11 @@ Status QnnBackendManager::CreateContext() { } QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT; - QnnHtpContext_CustomConfig_t customConfig; - customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; - customConfig.weightSharingEnabled = enable_htp_weight_sharing_; + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + custom_config.weightSharingEnabled = enable_htp_weight_sharing_; context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; - context_config_weight_sharing.customConfig = &customConfig; + context_config_weight_sharing.customConfig = &custom_config; QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config)); @@ -606,9 +616,78 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 return context_buffer; } +Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer, + uint64_t buffer_length, + uint64_t& max_spill_fill_buffer_size) { + max_spill_fill_buffer_size = 0; + // spill fill starts from 2.28 +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) + bool result = nullptr == qnn_sys_interface_.systemContextCreate || + nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || + nullptr == qnn_sys_interface_.systemContextFree; + ORT_RETURN_IF(result, "Failed to get valid function pointer."); + + QnnSystemContext_Handle_t sys_ctx_handle = nullptr; + auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle."); + + const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; + Qnn_ContextBinarySize_t binary_info_size{0}; + rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, + static_cast(buffer), + buffer_length, + &binary_info, + &binary_info_size); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info."); + + // binary_info life cycle is here + // Binary info to graph info + // retrieve Qnn graph info from binary info + ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr."); + uint32_t graph_count = 0; + QnnSystemContext_GraphInfo_t* graphs_info = nullptr; + if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + graph_count = binary_info->contextBinaryInfoV3.numGraphs; + graphs_info = binary_info->contextBinaryInfoV3.graphs; + } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + graph_count = binary_info->contextBinaryInfoV2.numGraphs; + graphs_info = binary_info->contextBinaryInfoV2.graphs; + } else if (binary_info->version == 
QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + graph_count = binary_info->contextBinaryInfoV1.numGraphs; + graphs_info = binary_info->contextBinaryInfoV1.graphs; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version."); + } + + for (uint32_t i = 0; i < graph_count; ++i) { + if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + auto htp_graph_info = reinterpret_cast(graphs_info[i].graphInfoV3.graphBlobInfo); + if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) { + auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize; + max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size; + } else { + LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version."; + } + } else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 || + graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { + LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2."; + } else { + LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version."; + } + } +#else + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_length); +#endif + + LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed."; + return Status::OK(); +} + Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - QnnModelLookupTable& qnn_models) { + QnnModelLookupTable& qnn_models, + int64_t max_spill_fill_size) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || nullptr == qnn_sys_interface_.systemContextFree; @@ -629,28 +708,60 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t // binary_info life cycle is here // Binary info to graph info - // retrieve Qnn graph infor from binary info + // retrieve Qnn graph info from binary info ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr."); uint32_t graph_count = 0; QnnSystemContext_GraphInfo_t* graphs_info = nullptr; if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { graph_count = binary_info->contextBinaryInfoV1.numGraphs; graphs_info = binary_info->contextBinaryInfoV1.graphs; - } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + } +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 15) // starts from 2.22 + else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { graph_count = binary_info->contextBinaryInfoV2.numGraphs; graphs_info = binary_info->contextBinaryInfoV2.graphs; } +#endif +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) // starts from 2.28 + else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + graph_count = binary_info->contextBinaryInfoV3.numGraphs; + graphs_info = binary_info->contextBinaryInfoV3.graphs; + } +#endif + else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version."); + } ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count; - ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, - "Invalid function pointer for contextCreateFromBinary."); - QnnContext_Config_t 
qnn_context_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config)); - const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr}; + // Register spill fill buffer for multi context + QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT; + + // The spill fill buffer is available since 2.28, API version starts from 2.21 +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; + QnnHtpContext_GroupRegistration_t group_info; + size_t current_contexts_size = GetQnnContextSize(); + // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle + // note that we already move the context with max spill fill size to the beginning of the list + group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0; + group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0 + custom_config.groupRegistration = group_info; + spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + spill_fill_config.customConfig = &custom_config; +#endif + QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr; + LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size; + + const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr}; + + ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, + "Invalid function pointer for contextCreateFromBinary."); Qnn_ContextHandle_t context = nullptr; rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, @@ -659,7 +770,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, &context, profile_backend_handle_); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary."); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. 
Error code: ", rt); contexts_.push_back(context); if (1 == graph_count) { // in case the EPContext node is generated from script @@ -685,7 +796,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t return Status::OK(); } -Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) { +// need to load system lib if load from Qnn context binary +// or generate Qnn context binary is enabled -- to get the max spill fill buffer size +Status QnnBackendManager::SetupBackend(const logging::Logger& logger, + bool load_from_cached_context, + bool need_load_system_lib) { + std::lock_guard lock(logger_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; return Status::OK(); @@ -699,7 +815,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ LOGS(logger, VERBOSE) << "LoadBackend succeed."; - if (load_from_cached_context) { + if (load_from_cached_context || need_load_system_lib) { ORT_RETURN_IF_ERROR(LoadQnnSystemLib()); } @@ -918,20 +1034,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_ return Status::OK(); } -void QnnBackendManager::Split(std::vector& split_string, - const std::string& tokenized_string, - const char separator) { - split_string.clear(); - std::istringstream tokenized_string_stream(tokenized_string); - while (!tokenized_string_stream.eof()) { - std::string value; - getline(tokenized_string_stream, value, separator); - if (!value.empty()) { - split_string.push_back(value); - } - } -} - Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); @@ -972,6 +1074,7 @@ void QnnBackendManager::ReleaseResources() { ORT_THROW("Failed to ShutdownBackend."); } + std::lock_guard lock(logger_mutex_); result = TerminateQnnLog(); if (Status::OK() != result) { ORT_THROW("Failed to TerminateQnnLog."); @@ -1025,7 +1128,14 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { const QnnProfile_EventId_t* profile_events{nullptr}; uint32_t num_events{0}; Qnn_ErrorHandle_t result = qnn_interface_.profileGetEvents(profile_backend_handle_, &profile_events, &num_events); - ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + if (!qnn_saver_path_.empty()) { // Using QNN Saver backend + // QNN SDK 2.28.2 returns QNN_SAVER_ERROR_DUMMY_RETVALUE, but previous QNN versions return QNN_PROFILE_NO_ERROR. + // We accept both values. + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result && QNN_SAVER_ERROR_DUMMY_RETVALUE != result, + "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + } else { + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile events. 
Error: ", QnnErrorHandleToString(result)); + } if (num_events > 0) { LOGS(*logger_, VERBOSE) << "profile_events: " << profile_events << " num_events: " << num_events; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b80f1374fcdc7..b145f2a2cd724 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -12,9 +12,11 @@ #endif #include +#include #include #include #include + #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" #include "QnnTypes.h" @@ -91,9 +93,10 @@ class QnnBackendManager { Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - std::unordered_map>& qnn_models); + std::unordered_map>& qnn_models, + int64_t max_spill_fill_size); - Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); + Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib); Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); @@ -110,6 +113,10 @@ class QnnBackendManager { return contexts_[index]; } + size_t GetQnnContextSize() { + return contexts_.size(); + } + const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; } const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; } @@ -143,8 +150,6 @@ class QnnBackendManager { void ReleaseResources(); - void Split(std::vector& split_string, const std::string& tokenized_string, const char separator); - Status ExtractBackendProfilingInfo(); Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile, bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled); @@ -161,6 +166,10 @@ class QnnBackendManager { Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); + Status GetMaxSpillFillBufferSize(unsigned char* buffer, + uint64_t buffer_length, + uint64_t& max_spill_fill_buffer_size); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -233,6 +242,7 @@ class QnnBackendManager { private: const std::string backend_path_; + std::mutex logger_mutex_; const logging::Logger* logger_ = nullptr; QNN_INTERFACE_VER_TYPE qnn_interface_ = QNN_INTERFACE_VER_TYPE_INIT; QNN_SYSTEM_INTERFACE_VER_TYPE qnn_sys_interface_ = QNN_SYSTEM_INTERFACE_VER_TYPE_INIT; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index f322456e0c8f0..4f73e4c532ed4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -95,6 +95,7 @@ const NodeUnit& QnnModel::GetNodeUnit(const Node* node, Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const qnn::ModelSettings& model_settings, const logging::Logger& logger, const QnnGraph_Config_t** graph_configs) { LOGS(logger, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name(); @@ -103,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This name must be same with the EPContext node 
name const auto& graph_name = fused_node.Name(); @@ -115,7 +116,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, model_input_index_map_, model_output_index_map_, initializer_inputs_, - qnn_backend_manager_->GetQnnBackendType()); + qnn_backend_manager_->GetQnnBackendType(), + model_settings); bool rt = true; rt = qnn_model_wrapper.CreateQnnGraph(qnn_backend_manager_->GetQnnContext(), graph_name, graph_configs); if (!rt) { @@ -245,7 +247,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging:: { // Acquire mutex before calling graphExecute and profiling APIs to support calling session.Run() // from multiple threads. - std::lock_guard lock(graph_exec_mutex_); + std::lock_guard lock(graph_exec_mutex_); execute_status = qnn_interface.graphExecute(graph_info_->Graph(), qnn_inputs.data(), static_cast(qnn_inputs.size()), @@ -319,29 +321,57 @@ Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_Graph std::vector output_tensor_wrappers; std::string graph_name; + Qnn_Tensor_t* input_tensors = nullptr; + Qnn_Tensor_t* output_tensors = nullptr; + uint32_t graph_input_num = 0; + uint32_t graph_output_num = 0; if (qnn_sys_ctx_graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { graph_name.assign(qnn_sys_ctx_graph_info.graphInfoV1.graphName); - auto graph_input_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphInputs; - auto graph_output_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphOutputs; - ORT_RETURN_IF(nullptr == qnn_sys_ctx_graph_info.graphInfoV1.graphInputs, "Graph from cached context doesn't have any inputs."); - ORT_RETURN_IF(nullptr == qnn_sys_ctx_graph_info.graphInfoV1.graphOutputs, "Graph from cached context doesn't have any outputs."); - - // Copy graph input - Qnn_Tensor_t* input_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphInputs; - for (size_t i = 0; i < graph_input_num; ++i) { - QnnTensorWrapper tensorwrapper; - ORT_RETURN_IF_ERROR(tensorwrapper.Init(input_tensors[i])); - input_tensor_wrappers.push_back(std::move(tensorwrapper)); - } + graph_input_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphInputs; + graph_output_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphOutputs; - // Copy graph output - Qnn_Tensor_t* output_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphOutputs; - for (size_t i = 0; i < graph_output_num; ++i) { - QnnTensorWrapper tensorwrapper; - ORT_RETURN_IF_ERROR(tensorwrapper.Init(output_tensors[i])); - output_tensor_wrappers.push_back(std::move(tensorwrapper)); - } + input_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphInputs; + output_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphOutputs; + } +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 18) // start from 2.25 + else if (qnn_sys_ctx_graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) { + graph_name.assign(qnn_sys_ctx_graph_info.graphInfoV2.graphName); + graph_input_num = qnn_sys_ctx_graph_info.graphInfoV2.numGraphInputs; + graph_output_num = qnn_sys_ctx_graph_info.graphInfoV2.numGraphOutputs; + + input_tensors = qnn_sys_ctx_graph_info.graphInfoV2.graphInputs; + output_tensors = qnn_sys_ctx_graph_info.graphInfoV2.graphOutputs; + } +#endif +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) // start from 2.28 + else if (qnn_sys_ctx_graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + graph_name.assign(qnn_sys_ctx_graph_info.graphInfoV3.graphName); + graph_input_num = qnn_sys_ctx_graph_info.graphInfoV3.numGraphInputs; + graph_output_num = 
qnn_sys_ctx_graph_info.graphInfoV3.numGraphOutputs; + + input_tensors = qnn_sys_ctx_graph_info.graphInfoV3.graphInputs; + output_tensors = qnn_sys_ctx_graph_info.graphInfoV3.graphOutputs; } +#endif + else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context graph info version."); + } + ORT_RETURN_IF(nullptr == input_tensors, "Graph from cached context doesn't have any inputs."); + ORT_RETURN_IF(nullptr == output_tensors, "Graph from cached context doesn't have any outputs."); + + // Copy graph input + for (size_t i = 0; i < graph_input_num; ++i) { + QnnTensorWrapper tensorwrapper; + ORT_RETURN_IF_ERROR(tensorwrapper.Init(input_tensors[i])); + input_tensor_wrappers.push_back(std::move(tensorwrapper)); + } + // Copy graph output + for (size_t i = 0; i < graph_output_num; ++i) { + QnnTensorWrapper tensorwrapper; + ORT_RETURN_IF_ERROR(tensorwrapper.Init(output_tensors[i])); + output_tensor_wrappers.push_back(std::move(tensorwrapper)); + } + Qnn_GraphHandle_t graph; auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); auto rt = qnn_interface.graphRetrieve(context, graph_name.c_str(), &graph); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 83cf8f9f08fb0..2e0935391ca78 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -8,7 +8,7 @@ #include "core/common/status.h" #include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" @@ -35,6 +35,7 @@ class QnnModel { Status ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const qnn::ModelSettings& model_settings, const logging::Logger& logger, const QnnGraph_Config_t** graph_configs = nullptr); @@ -142,7 +143,7 @@ class QnnModel { QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; // Mutex acquired during graph execution to support multi-threaded inference of a single session. 
- OrtMutex graph_exec_mutex_; + std::mutex graph_exec_mutex_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 9ab122b7f8e28..f3e52050e79e0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -29,6 +29,10 @@ struct TensorInfo { const ONNX_NAMESPACE::TensorProto* initializer_tensor; }; +struct ModelSettings { + bool offload_graph_io_quantization = false; +}; + class QnnModelWrapper { public: QnnModelWrapper(const GraphViewer& graph_viewer, @@ -38,7 +42,8 @@ class QnnModelWrapper { const std::unordered_map& input_index_map, const std::unordered_map& output_index_map, const std::unordered_set& initializer_lookup, - QnnBackendType qnn_backend_type) + QnnBackendType qnn_backend_type, + const ModelSettings& model_settings) : graph_viewer_(graph_viewer), logger_(logger), qnn_interface_(qnn_interface), @@ -46,12 +51,15 @@ class QnnModelWrapper { input_index_map_(input_index_map), output_index_map_(output_index_map), initializer_lookup_(initializer_lookup), - qnn_backend_type_(qnn_backend_type) { + qnn_backend_type_(qnn_backend_type), + model_settings_(model_settings) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnModelWrapper); ~QnnModelWrapper() = default; + const ModelSettings& GetModelSettings() const { return model_settings_; } + bool CreateQnnGraph(const Qnn_ContextHandle_t& context, const std::string& graph_name, const QnnGraph_Config_t** graph_configs = nullptr); @@ -279,6 +287,7 @@ class QnnModelWrapper { const std::unordered_map& output_index_map_; const std::unordered_set& initializer_lookup_; QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; + ModelSettings model_settings_ = {}; }; // QnnModelWrapper } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 24132b98e3757..27e195dea73d2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -36,8 +36,8 @@ constexpr const char* QNN = "QNN"; static std::unique_ptr>> s_run_on_unload_; void RunOnUnload(std::function function) { - static OrtMutex mutex; - std::lock_guard guard(mutex); + static std::mutex mutex; + std::lock_guard guard(mutex); if (!s_run_on_unload_) { s_run_on_unload_ = std::make_unique>>(); } @@ -161,6 +161,23 @@ static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevic } } +static bool ParseBoolOption(const std::string& key, bool default_value, + const std::unordered_map& options) { + bool result = default_value; + auto it = options.find(key); + if (it != options.end()) { + if ("1" == it->second) { + result = true; + } else if ("0" == it->second) { + result = false; + } else { + LOGS_DEFAULT(VERBOSE) << "Invalid value for " << key << " (" << it->second << "). 
Only 0 or 1 allowed."; + } + LOGS_DEFAULT(VERBOSE) << "Using " << key << ": " << result; + } + return result; +} + qnn::ProfilingLevel QNNExecutionProvider::GetProfilingLevelFromETWLevel(unsigned char level) { if (level == 5) { LOGS_DEFAULT(INFO) << "Overriding profiling to basic based on ETW level: " << static_cast(level); @@ -187,7 +204,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "Context cache enable: " << context_cache_enabled_; std::string embed_mode = session_options->config_options.GetConfigOrDefault( - kOrtSessionOptionEpContextEmbedMode, "1"); + kOrtSessionOptionEpContextEmbedMode, "0"); if ("1" == embed_mode) { qnn_context_embed_mode_ = true; } else if ("0" == embed_mode) { @@ -241,49 +258,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } -#ifdef _WIN32 - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - // Register callback for ETW capture state (rundown) - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - ORT_UNUSED_PARAMETER(SourceId); - ORT_UNUSED_PARAMETER(MatchAnyKeyword); - ORT_UNUSED_PARAMETER(MatchAllKeyword); - ORT_UNUSED_PARAMETER(FilterData); - ORT_UNUSED_PARAMETER(CallbackContext); - - if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); - } - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { - if (Level != 0) { - // Commenting out Dynamic QNN Profiling for now - // There seems to be a crash in 3rd party QC QnnHtp.dll with this. - // Repro Scenario - start ETW tracing prior to session creation. 
- // Then disable/enable ETW Tracing with the code below uncommented a few times - // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); - // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); - } - } - } - - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); - (void)qnn_backend_manager_->ResetQnnLogLevel(); - } - }); - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); -#endif - // In case ETW gets disabled later auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); if (profiling_level_pos != provider_options_map.end()) { @@ -389,18 +363,31 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_; } + bool enable_htp_weight_sharing = false; static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing"; auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED); if (htp_weight_sharing_enabled_pos != provider_options_map.end()) { if ("1" == htp_weight_sharing_enabled_pos->second) { - enable_htp_weight_sharing_ = true; + enable_htp_weight_sharing = true; } else if ("0" == htp_weight_sharing_enabled_pos->second) { - enable_htp_weight_sharing_ = false; + enable_htp_weight_sharing = false; } else { - LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_ + LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing << " only 0 or 1 allowed. Set to 0."; } - LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_; + LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing; + } + + // Add this option because this feature requires QnnSystem lib and it's no supported for Windows x86_64 platform + enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map); + + model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false, + provider_options_map); + + if (disable_cpu_ep_fallback_ && model_settings_.offload_graph_io_quantization) { + LOGS_DEFAULT(WARNING) << "Fallback to CPU EP is disabled, but user configured QNN EP to offload graph I/O " + << "quantization/dequantization to another EP. 
Session creation will fail if the CPU EP " + << "handles the graph I/O quantization/dequantization."; } qnn_backend_manager_ = std::make_unique( @@ -413,12 +400,55 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio device_id_, htp_arch, soc_model, - enable_htp_weight_sharing_); + enable_htp_weight_sharing); + +#ifdef _WIN32 + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); + } + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { + if (Level != 0) { + // Commenting out Dynamic QNN Profiling for now + // There seems to be a crash in 3rd party QC QnnHtp.dll with this. + // Repro Scenario - start ETW tracing prior to session creation. + // Then disable/enable ETW Tracing with the code below uncommented a few times + // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); + // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); + } + } + } + + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); + (void)qnn_backend_manager_->ResetQnnLogLevel(); + } + }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); +#endif } QNNExecutionProvider::~QNNExecutionProvider() { // clean up thread local context caches - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { const auto cache = cache_weak.lock(); if (!cache) continue; @@ -427,7 +457,9 @@ QNNExecutionProvider::~QNNExecutionProvider() { // Unregister the ETW callback #ifdef _WIN32 - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif } @@ -499,7 +531,8 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, model_input_index_map, model_output_index_map, initializer_input_lookup, - qnn_backend_manager_->GetQnnBackendType()); + qnn_backend_manager_->GetQnnBackendType(), + model_settings_); std::vector> qnn_node_groups; qnn_node_groups.reserve(node_unit_size); @@ -657,7 +690,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // It will load the QnnSystem lib if is_qnn_ctx_model=true, and // delay the Qnn context creation to Compile() using the cached context binary - auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model); + // or generate context cache enable, need to use use 
QnnSystem lib to parse the binary to get the max spill fill buffer size + auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_); if (Status::OK() != rt) { LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage(); return result; @@ -684,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // remove is_qnn_ctx_model related code const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, @@ -845,7 +879,8 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vectorComposeGraph(graph_viewer, fused_node, logger, graph_configs_builder.GetQnnConfigs())); + ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, model_settings_, logger, + graph_configs_builder.GetQnnConfigs())); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs(logger)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger)); @@ -904,6 +939,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused std::vector main_context_pos_list; ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list)); + uint32_t total_context_size = SafeInt(main_context_pos_list.size()); + + int64_t max_spill_fill_size = 0; + + // Adjust the main_context_pos_list, move the one with max spill fill buffer to the beginning + // HTP spill fill buffer only works for multiple QNN contexts generated after QNN v2.28 + if (total_context_size > 1) { + ORT_RETURN_IF_ERROR(qnn::TryGetMaxSpillFillSize(fused_nodes_and_graphs, total_context_size, + max_spill_fill_size, main_context_pos_list)); + } for (auto main_context_pos : main_context_pos_list) { const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); @@ -912,7 +957,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused context_cache_path, qnn_backend_manager_.get(), qnn_models, - logger)); + logger, + max_spill_fill_size)); } for (auto fused_node_and_graph : fused_nodes_and_graphs) { @@ -954,6 +1000,13 @@ Status QNNExecutionProvider::Compile(const std::vector& fused // All partitioned graph share single QNN context, included in the same context binary uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); + // Get max spill fill buffer size + uint64_t max_spill_fill_buffer_size = 0; + if (enable_spill_fill_buffer_) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(), + buffer_size, + max_spill_fill_buffer_size)); + } qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(), context_buffer.get(), @@ -963,6 +1016,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_models_, context_cache_path, qnn_context_embed_mode_, + max_spill_fill_buffer_size, logger)); } return Status::OK(); @@ -1022,7 +1076,7 @@ QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContex // get context and update cache std::shared_ptr context; { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); // get or create a context if (context_state_.retired_context_pool.empty()) { @@ 
-1056,7 +1110,7 @@ void QNNExecutionProvider::ReleasePerThreadContext() const { ORT_ENFORCE(cached_context); { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); context_state_.active_contexts.erase(cached_context); context_state_.retired_context_pool.push_back(cached_context); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index e0eaf31c94a36..a0577e8fd87f2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -31,7 +31,7 @@ class SharedContext { } bool HasSharedQnnModels() { - const std::lock_guard lock(mtx_); + const std::lock_guard lock(mtx_); return !shared_qnn_models_.empty(); } @@ -42,7 +42,7 @@ class SharedContext { } std::unique_ptr GetSharedQnnModel(const std::string& model_name) { - const std::lock_guard lock(mtx_); + const std::lock_guard lock(mtx_); auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); if (it == shared_qnn_models_.end()) { @@ -55,7 +55,7 @@ class SharedContext { bool SetSharedQnnModel(std::vector>&& shared_qnn_models, std::string& duplicate_graph_names) { - const std::lock_guard lock(mtx_); + const std::lock_guard lock(mtx_); bool graph_exist = false; for (auto& shared_qnn_model : shared_qnn_models) { auto& model_name = shared_qnn_model->Name(); @@ -81,7 +81,7 @@ class SharedContext { std::vector> shared_qnn_models_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized - OrtMutex mtx_; + std::mutex mtx_; }; // Logical device representation. @@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider { std::string context_node_name_prefix_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; - bool enable_htp_weight_sharing_ = false; int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; ModelMetadefIdGenerator metadef_id_generator_; @@ -150,9 +149,11 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; + bool enable_spill_fill_buffer_ = false; #ifdef _WIN32 - onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif + qnn::ModelSettings model_settings_ = {}; class PerThreadContext final { public: @@ -201,7 +202,7 @@ class QNNExecutionProvider : public IExecutionProvider { std::set, std::owner_less>> caches_to_update_on_destruction; // synchronizes access to PerThreadContextState members - OrtMutex mutex; + std::mutex mutex; }; // The execution provider maintains the PerThreadContexts in this structure. 
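Note on the two boolean provider options introduced above: enable_htp_spill_fill_buffer and offload_graph_io_quantization are read through ParseBoolOption, which accepts only the strings "0" and "1" and otherwise keeps the default, and the spill-fill path is only exercised when an EPContext binary is generated or loaded. The sketch below is not part of this change; it assumes the public C++ API's AppendExecutionProvider overload that takes a string-keyed option map, an HTP backend library named libQnnHtp.so, the session config key ep.context_enable, and a placeholder model path.

// Sketch only (not part of the diff): enabling the new QNN EP options from client code.
// Assumptions: onnxruntime_cxx_api.h is available, the HTP backend lives at libQnnHtp.so,
// "ep.context_enable" is the EPContext session config key, and model.onnx is a placeholder.
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_spill_fill_demo");
  Ort::SessionOptions so;

  // Both new options are parsed by ParseBoolOption: only "0" or "1" is accepted;
  // any other value falls back to the default (false).
  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "libQnnHtp.so"},
      {"enable_htp_spill_fill_buffer", "1"},
      {"offload_graph_io_quantization", "1"},
  };
  so.AppendExecutionProvider("QNN", qnn_options);

  // Spill-fill sizing only runs together with context-binary generation or loading,
  // so enable EPContext caching for this session (key name is an assumption here).
  so.AddConfigEntry("ep.context_enable", "1");

  Ort::Session session(env, ORT_TSTR("model.onnx"), so);  // placeholder model path
  return 0;
}

Passing the options as "0"/"1" strings matches how the other QNN provider options touched in this file (for example enable_htp_fp16_precision and enable_htp_weight_sharing) are parsed.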
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index cdb4d1f7edac6..b8fe875ba54b7 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -5,9 +5,12 @@ #include #include #include +#include #include +#include #include #include +//#include #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/shared_inc/rocm_call.h" @@ -242,12 +245,63 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); } template <> __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); } +#define ISNAN_BFLOAT16(v__) static_cast(*reinterpret_cast(&v__) & ~BFloat16::kSignMask) \ + > BFloat16::kPositiveInfinityBits + +// Note that there is no consistent canonical NaN for FP16 and BF16; +// HIP uses 0x7FFF for HIPRT_NAN_BF16, but ONNX Runtime uses 0x7FC1. +// (see BFloat16Impl::kPositiveQNaNBits). +#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU) + template __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; } +template <> +__device__ __inline__ float _Min(float a, float b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); +} + +template <> +__device__ __inline__ double _Min(double a, double b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); +} + +template <> +__device__ __inline__ half _Min(half a, half b) { + return __hmin_nan(a, b); +} + +template <> +__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) { + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b); +} + template __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } +template <> +__device__ __inline__ float _Max(float a, float b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); +} + +template <> +__device__ __inline__ double _Max(double a, double b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); +} + +template <> +__device__ __inline__ half _Max(half a, half b) { + return __hmax_nan(a, b); +} + +template <> +__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) { + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b); +} + +#undef ISNAN_BFLOAT16 +#undef NAN_BFLOAT16 + template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? 
a : -a; } @@ -443,36 +497,36 @@ struct _IsNan { template <> struct _IsNan { __device__ __inline__ bool operator()(half a) const { - return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) - > MLFloat16::kPositiveInfinityBits; + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + > MLFloat16::kPositiveInfinityBits; } }; template <> struct _IsNan { __device__ __inline__ bool operator()(BFloat16 a) const { - return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) - > BFloat16::kPositiveInfinityBits; + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + > BFloat16::kPositiveInfinityBits; } }; #if !defined(DISABLE_FLOAT8_TYPES) -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E4M3FN a) const { return (*reinterpret_cast(&a) & 0x7f) == 0x7f; } }; -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { return *reinterpret_cast(&a) == 0x80; } }; -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E5M2 a) const { uint8_t c = *reinterpret_cast(&a); @@ -480,7 +534,7 @@ struct _IsNan { } }; -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { return *reinterpret_cast(&a) == 0x80; diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu index d130758bec084..18edb359f6062 100644 --- a/onnxruntime/core/providers/rocm/fpgeneric.cu +++ b/onnxruntime/core/providers/rocm/fpgeneric.cu @@ -53,29 +53,29 @@ __global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onn } // namespace -rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) { +hipblasStatus_t hipblasTransposeHelper(hipStream_t stream, hipblasHandle_t, hipblasOperation_t , hipblasOperation_t , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) { if (C != A) { dim3 dimGrid((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1); dim3 dimBlock(TRANS_TILE_DIM, BLOCK_ROWS, 1); transposeNoOverlap<<>>(C, A, n, m); } else { - return rocblas_status_not_implemented; + return HIPBLAS_STATUS_NOT_SUPPORTED; } - return rocblas_status_success; + return HIPBLAS_STATUS_SUCCESS; } -rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const half* x, int incx, half* y, int incy) { +hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t, int n, const half* x, int incx, half* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); CopyVectorHalf<<>>(x, incx, y, incy, n); - return rocblas_status_success; + return HIPBLAS_STATUS_SUCCESS; } -rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx, +hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t, int n, const onnxruntime::BFloat16* x, int incx, onnxruntime::BFloat16* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); CopyVectorBFloat16<<>>(x, incx, y, incy, n); - return rocblas_status_success; + return HIPBLAS_STATUS_SUCCESS; } diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc index 
635a25480b646..281a6f35a2808 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc @@ -6,10 +6,8 @@ #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/rocm_common.h" +// If you make changes below, please also update onnxruntime/core/providers/migraphx/gpu_data_transfer.cc namespace onnxruntime { -GPUDataTransfer::GPUDataTransfer() {} - -GPUDataTransfer::~GPUDataTransfer() {} bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || @@ -30,19 +28,23 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const // Copy only if the two addresses are different. if (dst_data != src_data) { HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else { // copy from other CPU memory to GPU, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); + if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) { + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); + } } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } else { // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } @@ -59,7 +61,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::CPU) { - // copy from pinned memory to GPU, this is non-blocking + // If the source is not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking @@ -68,15 +71,15 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, } } } else if (src_device.Type() == OrtDevice::GPU) { - if (dst_device.Type() == OrtDevice::CPU) { - // copying from GPU to pinned memory, this is non-blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - } + // If the destination is not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. 
+ HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); } else { if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { // sync the stream first to make sure the data arrived HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h index 3d297bdce4a93..3d35ed52fff5c 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.h @@ -10,8 +10,8 @@ namespace onnxruntime { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer(); - ~GPUDataTransfer(); + GPUDataTransfer() = default; + ~GPUDataTransfer() = default; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; diff --git a/onnxruntime/core/providers/rocm/integer_gemm.cc b/onnxruntime/core/providers/rocm/integer_gemm.cc index 9771f42fd3637..2d6ee89239cee 100644 --- a/onnxruntime/core/providers/rocm/integer_gemm.cc +++ b/onnxruntime/core/providers/rocm/integer_gemm.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include #include "core/providers/rocm/shared_inc/integer_gemm.h" #include "core/common/safeint.h" @@ -27,7 +25,7 @@ Status GemmInt8(int m, int n, int k, hipStream_t stream = static_cast(ort_stream->GetHandle()); // pad A and B to make their leading dimension be multiples of 32 - // because rocblas_gemm_ex requires: + // because hipblasGemmEx requires: // 1. leading dimension is multiples of 4 // 2. A, B is 32-bit aligned @@ -49,21 +47,19 @@ Status GemmInt8(int m, int n, int k, } auto* ort_rocm_stream = dynamic_cast(ort_stream); - auto rocblas = ort_rocm_stream->rocblas_handle_; + auto hipblas = ort_rocm_stream->hipblas_handle_; - ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex( - rocblas, - rocblas_operation_none, rocblas_operation_none, + HIPBLAS_RETURN_IF_ERROR(hipblasGemmEx( + hipblas, + HIPBLAS_OP_N, HIPBLAS_OP_N, n, m, k, &alpha, - ldb_aligned == ldb ? b : b_padded.get(), rocblas_datatype_i8_r, ldb_aligned, - lda_aligned == lda ? a : a_padded.get(), rocblas_datatype_i8_r, lda_aligned, + ldb_aligned == ldb ? b : b_padded.get(), HIP_R_8I, ldb_aligned, + lda_aligned == lda ? a : a_padded.get(), HIP_R_8I, lda_aligned, &beta, - c, rocblas_datatype_i32_r, ldc, - c, rocblas_datatype_i32_r, ldc, // C == D - rocblas_datatype_i32_r, - rocblas_gemm_algo_standard, - 0, 0)); + c, HIP_R_32I, ldc, + HIPBLAS_COMPUTE_32I, + HIPBLAS_GEMM_DEFAULT)); return Status::OK(); } } // namespace rocm diff --git a/onnxruntime/core/providers/rocm/math/einsum.cc b/onnxruntime/core/providers/rocm/math/einsum.cc index 5ebe6fba26a6b..808ca2a31cc4e 100644 --- a/onnxruntime/core/providers/rocm/math/einsum.cc +++ b/onnxruntime/core/providers/rocm/math/einsum.cc @@ -31,8 +31,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vectorGetComputeStream(); ORT_RETURN_IF(!stream, "stream is null"); auto* rocm_stream = static_cast(stream); - rocblas_handle rocblas_handle = rocm_stream ? rocm_stream->rocblas_handle_ : nullptr; - EinsumOp::EinsumRocmAssets einsum_rocm_assets(rocblas_handle, rocm_ep_, stream, Info().GetAllocator(OrtMemType::OrtMemTypeDefault)); + hipblasHandle_t hipblas_handle = rocm_stream ? 
rocm_stream->hipblas_handle_ : nullptr; + EinsumOp::EinsumRocmAssets einsum_rocm_assets(hipblas_handle, rocm_ep_, stream, Info().GetAllocator(OrtMemType::OrtMemTypeDefault)); // EinsumComputePreprocessor section - auto einsum_compute_preprocessor = EinsumComputePreprocessor::Create(*einsum_equation_preprocessor_, inputs, allocator, diff --git a/onnxruntime/core/providers/rocm/math/einsum.h b/onnxruntime/core/providers/rocm/math/einsum.h index 6be412348e6dd..c62e219a66499 100644 --- a/onnxruntime/core/providers/rocm/math/einsum.h +++ b/onnxruntime/core/providers/rocm/math/einsum.h @@ -15,7 +15,7 @@ namespace rocm { class Einsum final : public onnxruntime::Einsum { public: Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) { - // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method + // We need to cast away the const as PerThreadHipblasHandle() is currently a non-const method // TODO: Clean up the ROCMExecutionProvider interface to avoid this rocm_ep_ = static_cast(info.GetExecutionProvider()); } @@ -30,7 +30,7 @@ class Einsum final : public onnxruntime::Einsum { using onnxruntime::Einsum::einsum_equation_preprocessor_; using onnxruntime::Einsum::equation_; - // We need to access to the ROCM EP instance to get the rocblas/miopen handles + // We need to access to the ROCM EP instance to get the hipblas/miopen handles const ROCMExecutionProvider* rocm_ep_; }; diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc index abf351335ce25..553fe1dccb332 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc @@ -37,7 +37,7 @@ Status Transpose(const gsl::span& permutation, const Tensor& input Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets) { return rocm::Transpose::DoTranspose(static_cast(einsum_rocm_assets)->rocm_ep_->GetDeviceProp(), static_cast(einsum_rocm_assets)->GetRocmStream(), - static_cast(einsum_rocm_assets)->rocblas_handle_, + static_cast(einsum_rocm_assets)->hipblas_handle_, permutation, input, output, input_shape_override); } @@ -54,7 +54,7 @@ Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, static_cast( static_cast(einsum_rocm_assets)->rocm_ep_->GetTuningContext()), static_cast(einsum_rocm_assets)->ort_stream_, - static_cast(einsum_rocm_assets)->rocblas_handle_, + static_cast(einsum_rocm_assets)->hipblas_handle_, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, N, M, K, /*alpha=*/1.0f, diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h index e1fc3f40ee9a5..689c65ae29f82 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h @@ -20,9 +20,9 @@ namespace EinsumOp { // Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow struct EinsumRocmAssets { - explicit EinsumRocmAssets(rocblas_handle rocblas_handle, + explicit EinsumRocmAssets(hipblasHandle_t hipblas_handle, const ROCMExecutionProvider* rocm_ep, - Stream* ort_stream, AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle), + Stream* ort_stream, AllocatorPtr gpu_allocator) : hipblas_handle_(hipblas_handle), rocm_ep_(rocm_ep), ort_stream_(ort_stream), 
gpu_allocator_(gpu_allocator) {} @@ -31,7 +31,7 @@ struct EinsumRocmAssets { return ort_stream_ ? static_cast(ort_stream_->GetHandle()) : nullptr; } - rocblas_handle rocblas_handle_; + hipblasHandle_t hipblas_handle_; const ROCMExecutionProvider* rocm_ep_; Stream* ort_stream_; AllocatorPtr gpu_allocator_; diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu index 94bee88a469b3..e1c89a386dafc 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu @@ -1,4 +1,3 @@ -#include "hip/hip_runtime.h" // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. diff --git a/onnxruntime/core/providers/rocm/math/gemm.cc b/onnxruntime/core/providers/rocm/math/gemm.cc index bbd0a3037133c..529b48f736d50 100644 --- a/onnxruntime/core/providers/rocm/math/gemm.cc +++ b/onnxruntime/core/providers/rocm/math/gemm.cc @@ -85,9 +85,9 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { if (b_shape.Size() == 1) { // if B is (), (1,) or (1, 1), broadcast the scalar - ROCBLAS_RETURN_IF_ERROR(rocblasCopyHelper( + HIPBLAS_RETURN_IF_ERROR(hipblasCopyHelper( Stream(ctx), - GetRocblasHandle(ctx), + GetHipblasHandle(ctx), M * N, b_data, 0, @@ -96,7 +96,7 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { } else if (b_shape.NumDimensions() == 1 || b_shape[0] == 1) { // B is (N,) or (1, N), broadcast using Y(N,M) = 1 * B(N,1) x ones(1,M) + 0 * Y ORT_RETURN_IF_ERROR(tunable::blas::column_major::Gemm( - GetTuningContext(), ctx->GetComputeStream(), GetRocblasHandle(ctx), + GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), tunable::blas::BlasOp::NonTrans, tunable::blas::BlasOp::NonTrans, N, M, 1, @@ -108,7 +108,7 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { } else if (b_shape.NumDimensions() == 2 && b_shape[1] == 1) { // B is (M, 1), broadcast using Y(N,M) = 1 * ones(N,1) x B(1,M) + 0 * Y ORT_RETURN_IF_ERROR(tunable::blas::column_major::Gemm( - GetTuningContext(), ctx->GetComputeStream(), GetRocblasHandle(ctx), + GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), tunable::blas::BlasOp::NonTrans, tunable::blas::BlasOp::NonTrans, N, M, 1, @@ -125,7 +125,7 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { return tunable::blas::column_major::Gemm( GetTuningContext(), ctx->GetComputeStream(), - GetRocblasHandle(ctx), + GetHipblasHandle(ctx), trans_B_ ? BlasOp::Trans : BlasOp::NonTrans, trans_A_ ? 
BlasOp::Trans : BlasOp::NonTrans, N, M, K, diff --git a/onnxruntime/core/providers/rocm/math/matmul_impl.cc b/onnxruntime/core/providers/rocm/math/matmul_impl.cc index f9f33cb4c9725..e27a7e7575da7 100644 --- a/onnxruntime/core/providers/rocm/math/matmul_impl.cc +++ b/onnxruntime/core/providers/rocm/math/matmul_impl.cc @@ -78,11 +78,11 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, const int ldc = helper.Ldc(); int64_t stride_A, stride_B, stride_C, batch_count; - auto rocblas_handle = op->GetRocblasHandle(static_cast(stream)); + auto hipblasHandle_t = op->GetHipblasHandle(static_cast(stream)); if (helper.OutputOffsets().size() == 1) { return tunable::blas::column_major::Gemm( - op->GetTuningContext(), stream, rocblas_handle, + op->GetTuningContext(), stream, hipblasHandle_t, transB, transA, helper.N(), helper.M(), helper.K(), alpha, @@ -94,7 +94,7 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, transa, transb, trans_batch_a, trans_batch_b, stride_A, stride_B, stride_C, batch_count)) { return tunable::blas::column_major::StridedBatchedGemm( - op->GetTuningContext(), stream, rocblas_handle, + op->GetTuningContext(), stream, hipblasHandle_t, transB, transA, helper.N(), helper.M(), helper.K(), alpha, @@ -123,10 +123,10 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper, ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(stream)); ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(stream)); - // note that onnxruntime OrtValue is row major, while rocblas is column major, + // note that onnxruntime OrtValue is row major, while hipblas is column major, // so swap left/right operands return tunable::blas::column_major::BatchedGemm( - op->GetTuningContext(), stream, rocblas_handle, + op->GetTuningContext(), stream, hipblasHandle_t, transB, transA, helper.N(), helper.M(), helper.K(), alpha, diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc index 8d922d0bb4db1..a41934d38177d 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ b/onnxruntime/core/providers/rocm/math/softmax.cc @@ -157,7 +157,7 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { // Perform the transpose ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), Stream(ctx), - GetRocblasHandle(ctx), + GetHipblasHandle(ctx), permutation, *X, *temp_input)); transposed_input = std::move(temp_input); @@ -199,7 +199,7 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { // Perform the transpose to get the axes back to the original ordering ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), Stream(ctx), - GetRocblasHandle(ctx), + GetHipblasHandle(ctx), permutation, *intermediate_output, *Y)); } diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index d7f47d07a8fec..f99885634b6c7 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -324,7 +324,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) template Status Conv::ComputeInternal(OpKernelContext* context) const { - std::lock_guard lock(s_.mutex); + std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context)); if (s_.Y->Shape().Size() == 0) { return Status::OK(); diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h index bc9846203e57d..e6ebb5a380d3f 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ b/onnxruntime/core/providers/rocm/nn/conv.h 
@@ -3,7 +3,7 @@ #pragma once -#include "core/platform/ort_mutex.h" +#include #include "core/providers/rocm/rocm_kernel.h" #include "core/providers/rocm/miopen_common.h" #include "core/providers/cpu/nn/conv_attributes.h" @@ -158,7 +158,7 @@ struct MiopenConvState { TensorShapeVector slice_axes; // note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing - OrtMutex mutex; + std::mutex mutex; IAllocatorUniquePtr memory_for_miopen_conv_results; ~MiopenConvState() { diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc index 7447113fdf847..a6848e90b406d 100644 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc @@ -66,7 +66,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy } { - std::lock_guard lock(s_.mutex); + std::lock_guard lock(s_.mutex); // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index a1f5eba9a24c8..d8b7e26d17b65 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -16,140 +16,34 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace rocm { -// opset 11 explicitly added support for negative axis. implementation already allowed it. -#define REGISTER_KERNEL_TYPED(name, T) \ +#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 12, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register those with changes in OpSet12. 
-#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, 13, \ + begin, end, \ T, \ kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -// Register ReduceMin int64_t support in OpSet14. -#define REGISTER_KERNEL_TYPED_14(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 14, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + name, \ + kOnnxDomain, \ + version, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ name); -// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet -#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); +#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) -// Register with the latest version 13 -#define REGISTER_KERNEL_TYPED_13(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); +#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. 
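The reduction registration macros above are collapsed into two building blocks, one for a closed opset range and one for an open-ended "since" registration, which REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED and REGISTER_KERNEL_ARGMIN_OR_ARGMAX then compose. A stripped-down sketch of the same composition idea, with the real ONNX Runtime registration machinery replaced by hypothetical printf-based macros:

#include <cstdio>

// Hypothetical stand-ins for the versioned-range / open-ended registration macros.
#define REGISTER_RANGE(op, type, begin, end) \
  std::printf(#op "<" #type ">: opsets %d-%d\n", begin, end);
#define REGISTER_SINCE(op, type, version) \
  std::printf(#op "<" #type ">: opset %d onward\n", version);

// Mirrors REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur):
// one closed range up to `last`, then an open-ended registration from `cur`.
#define REGISTER_WITH_AXES_INPUT(op, type, last, cur) \
  REGISTER_RANGE(op, type, 1, last)                   \
  REGISTER_SINCE(op, type, cur)

int main() {
  REGISTER_WITH_AXES_INPUT(ReduceSum, float, 12, 13)  // axes became an input in opset 13
  return 0;
}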
template @@ -348,7 +242,9 @@ Status ReduceKernel::ReduceKernelShared( // double* Y, // const TensorShape& output_shape, // miopenReduceTensorOp_t miopen_reduce_op, -// std::vector& output_dims) const; +// miopenHandle_t miopen_handle, +// onnxruntime::Stream* stream, +// TensorShapeVector& output_dims) const; template Status ReduceKernel::ReduceKernelShared( const float* X, @@ -387,7 +283,7 @@ Status PrepareForReduce(const Tensor* X, } const auto input_dims = input_shape.GetDims(); - InlinedVector reduced(rank, false); + std::vector reduced(rank, false); if (axes.size() > 0) { prepare_reduce_metadata.output_dims = input_shape.AsShapeVector(); for (auto axis : axes) { @@ -724,11 +620,35 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +// template Status ReduceComputeCore( +// const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, +// /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, +// gsl::span axes, +// bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, +// Stream* ort_stream, +// const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const { const Tensor* X = ctx->Input(0); - std::vector axes; + TensorShapeVector axes; size_t num_inputs = ctx->InputCount(); const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input(1) : nullptr; // optional input. may be nullptr. 
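The `template Status ReduceComputeCore<...>` lines above are explicit template instantiations: they force object code for exactly the float and MLFloat16 element types so other translation units can link against the templated reduction core, while the double variant stays commented out. A minimal sketch of the pattern with a hypothetical function (not the real ReduceComputeCore signature):

// reduce_core_sketch.h
template <typename T>
int CountNonZero(const T* data, int n);

// reduce_core_sketch.cc
template <typename T>
int CountNonZero(const T* data, int n) {
  int count = 0;
  for (int i = 0; i < n; ++i) count += (data[i] != T{}) ? 1 : 0;
  return count;
}

// Explicit instantiations: only these types get emitted object code, mirroring
// the float / MLFloat16 ReduceComputeCore instantiations above.
template int CountNonZero<float>(const float*, int);
template int CountNonZero<unsigned short>(const unsigned short*, int);  // stand-in for MLFloat16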
@@ -904,7 +824,7 @@ template std::unique_ptr ReduceCompute axes, // bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp, -// bool fast_reduction, const TensorShape* input_shape_override); +// bool fast_reduction, Stream* stream, const TensorShape* input_shape_override); template std::unique_ptr ReduceCompute( const AllocatorPtr& gpu_allocator, miopenReduceTensorOp_t miopen_reduce_op, @@ -915,69 +835,75 @@ template std::unique_ptr ReduceCompute lock(lock_); + std::lock_guard lock(lock_); auto it = reserved_.find(p); if (it != reserved_.end()) { reserved_.erase(it); @@ -80,7 +80,7 @@ void ROCMExternalAllocator::Free(void* p) { void* ROCMExternalAllocator::Reserve(size_t size) { void* p = Alloc(size); if (!p) return nullptr; - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); ORT_ENFORCE(reserved_.find(p) == reserved_.end()); reserved_.insert(p); return p; diff --git a/onnxruntime/core/providers/rocm/rocm_allocator.h b/onnxruntime/core/providers/rocm/rocm_allocator.h index 04de09ab9c00b..ef13fc2e25cda 100644 --- a/onnxruntime/core/providers/rocm/rocm_allocator.h +++ b/onnxruntime/core/providers/rocm/rocm_allocator.h @@ -5,7 +5,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -42,7 +42,7 @@ class ROCMExternalAllocator : public ROCMAllocator { void* Reserve(size_t size) override; private: - mutable OrtMutex lock_; + mutable std::mutex lock_; ExternalAlloc alloc_; ExternalFree free_; ExternalEmptyCache empty_cache_; diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index ca12720fb3eb4..a73ef9b34b4de 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -33,7 +33,6 @@ const char* RocmErrString(hipError_t x) { template <> const char* RocmErrString(rocblas_status e) { ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - switch (e) { CASE_ENUM_TO_STR(rocblas_status_success); CASE_ENUM_TO_STR(rocblas_status_invalid_handle); @@ -53,6 +52,24 @@ const char* RocmErrString(rocblas_status e) { } } +template <> +const char* RocmErrString(hipblasStatus_t e) { + ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard + switch (e) { + CASE_ENUM_TO_STR(HIPBLAS_STATUS_SUCCESS); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_INITIALIZED); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_ALLOC_FAILED); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_INVALID_VALUE); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_ARCH_MISMATCH); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_MAPPING_ERROR); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_EXECUTION_FAILED); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_INTERNAL_ERROR); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_SUPPORTED); + default: + return "(look for HIPBLAS_STATUS_xxx in hipblas_api.h)"; + } +} + template <> const char* RocmErrString(hiprandStatus_t) { ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard @@ -76,7 +93,7 @@ const char* RocmErrString(hipfftResult e) { CASE_ENUM_TO_STR(HIPFFT_SETUP_FAILED); CASE_ENUM_TO_STR(HIPFFT_INVALID_SIZE); default: - return "Unknown cufft error status"; + return "Unknown hipfft error status"; } } @@ -135,6 +152,8 @@ std::conditional_t RocmCall( template Status RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); template void RocmCall(hipError_t retCode, const char* exprString, const char* libName, 
hipError_t successCode, const char* msg, const char* file, const int line); +template Status RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); +template void RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); template Status RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); template void RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); template Status RocmCall(miopenStatus_t retCode, const char* exprString, const char* libName, miopenStatus_t successCode, const char* msg, const char* file, const int line); @@ -151,9 +170,4 @@ template Status RocmCall(ncclResult_t retCode, const char* template void RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); #endif -#ifdef USE_HIPBLASLT -template Status RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -#endif - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_common.h b/onnxruntime/core/providers/rocm/rocm_common.h index a8ddb85233031..4af1f40a6fccc 100644 --- a/onnxruntime/core/providers/rocm/rocm_common.h +++ b/onnxruntime/core/providers/rocm/rocm_common.h @@ -17,6 +17,7 @@ namespace rocm { #define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) #define ROCBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(ROCBLAS_CALL(expr)) +#define HIPBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPBLAS_CALL(expr)) #define HIPSPARSE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPSPARSE_CALL(expr)) #define HIPRAND_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPRAND_CALL(expr)) #define MIOPEN_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(MIOPEN_CALL(expr)) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 7d741a6604679..0a427b146dcaa 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -10,6 +10,7 @@ #include "core/providers/rocm/rocm_fwd.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/rocm_profiler.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/rocm/rocm_contrib_kernels.h" @@ -40,11 +41,9 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - // do we support async copy? - // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, - // so we don't need the check here. 
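The rocm_call.cc and rocm_common.h hunks above give hipblasStatus_t the same treatment the other ROCm libraries already had: a switch-based status-to-string helper plus a RETURN_IF_ERROR-style macro. A self-contained sketch of that pattern against the real hipBLAS status enum; the helper and macro names here are hypothetical, and only a subset of status codes is shown.

#include <hipblas/hipblas.h>  // header path may differ across ROCm versions

#define STATUS_CASE(x) \
  case x:              \
    return #x

const char* HipblasStatusToString(hipblasStatus_t e) {
  switch (e) {
    STATUS_CASE(HIPBLAS_STATUS_SUCCESS);
    STATUS_CASE(HIPBLAS_STATUS_NOT_INITIALIZED);
    STATUS_CASE(HIPBLAS_STATUS_INVALID_VALUE);
    STATUS_CASE(HIPBLAS_STATUS_EXECUTION_FAILED);
    default:
      return "(look for HIPBLAS_STATUS_xxx in the hipBLAS headers)";
  }
}

// RETURN_IF_ERROR-style wrapper: evaluate once, bail out with the message on failure.
#define RETURN_STRING_IF_HIPBLAS_ERROR(expr)    \
  do {                                          \
    const hipblasStatus_t status_ = (expr);     \
    if (status_ != HIPBLAS_STATUS_SUCCESS)      \
      return HipblasStatusToString(status_);    \
  } while (0)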
- auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, - Y->Location().device); + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + // CopyTensorAsync could handle both pinned memory and non-pinned CPU memory. + // For non-pinned CPU memory, the copy is synchronous. ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); return Status::OK(); } else { @@ -89,12 +88,10 @@ class Memcpy final : public OpKernel { Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), - alloc); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device); - ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, - *ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream())); Y->Add(std::move(*target_tensor)); } return Status::OK(); @@ -130,8 +127,7 @@ ONNX_OPERATOR_KERNEL_EX( AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t gpu_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo - external_allocator_info, + ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( @@ -153,8 +149,7 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi device_id, true, {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, + : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware true, // enable cross stream sharing? 
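The reworded Memcpy comment above notes that CopyTensorAsync handles both pinned and ordinary (pageable) CPU memory, with the pageable case effectively synchronous. A small sketch of the underlying HIP behavior this relies on, using only the public runtime API; the wrapper function is hypothetical.

#include <hip/hip_runtime.h>
#include <cstddef>

// Enqueue a host-to-device copy on the compute stream. With pinned host memory
// the copy can overlap with other stream work; with pageable host memory the
// HIP runtime stages it and the call behaves close to synchronously.
hipError_t EnqueueHostToDevice(void* dst_device, const void* src_host,
                               size_t bytes, hipStream_t stream) {
  return hipMemcpyAsync(dst_device, src_host, bytes,
                        hipMemcpyHostToDevice, stream);
}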
@@ -165,16 +160,13 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi } } -ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, - size_t /*gpu_mem_limit*/, - ArenaExtendStrategy /*arena_extend_strategy*/, - ROCMExecutionProviderExternalAllocatorInfo - /*external_allocator_info*/, +ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t /*gpu_mem_limit*/, + ArenaExtendStrategy /*arena_extend_strategy*/, ROCMExecutionProviderExternalAllocatorInfo /*external_allocator_info*/, OrtArenaCfg* /*default_memory_arena_cfg*/) { HIP_CALL_THROW(hipSetDevice(device_id)); - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - ROCBLAS_CALL_THROW(rocblas_set_stream(rocblas_handle_, stream)); + HIPBLAS_CALL_THROW(hipblasCreate(&hipblas_handle_)); + HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); @@ -183,36 +175,64 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de } ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { - ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(rocblas_handle_))); + ORT_IGNORE_RETURN_VALUE(HIPBLAS_CALL(hipblasDestroy(hipblas_handle_))); ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); } -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed( + RocmGraphAnnotation_t hip_graph_annotation_id) const { + if (!IsGraphCaptureAllowedOnRun(hip_graph_annotation_id)) { + return false; + } + if (graph_id_to_run_count_.find(hip_graph_annotation_id) == graph_id_to_run_count_.end()) { + return false; + } + return graph_id_to_run_count_.at(hip_graph_annotation_id) >= min_num_runs_before_hip_graph_capture_; +} + +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun( + RocmGraphAnnotation_t hip_graph_annotation_id) const { + return hip_graph_.IsGraphCaptureAllowedOnRun(hip_graph_annotation_id); } -void ROCMExecutionProvider::PerThreadContext::CaptureBegin(int) { - hip_graph_.Reset(); - hip_graph_.CaptureBegin(0); +RocmGraphAnnotation_t ROCMExecutionProvider::PerThreadContext::GetRocmGraphAnnotationId( + const onnxruntime::RunOptions& run_options) const { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + // If graph annotation is not provided, fall back to the one hip graph per session behavior + RocmGraphAnnotation_t hip_graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(TryParseStringWithClassicLocale(*graph_annotation_str, hip_graph_annotation_id), + "Failed to parse the hip graph annotation id: ", + *graph_annotation_str); + } + + return hip_graph_annotation_id; } -void ROCMExecutionProvider::PerThreadContext::CaptureEnd(int) { - hip_graph_.CaptureEnd(0); - is_graph_captured_ = true; +void ROCMExecutionProvider::PerThreadContext::CaptureBegin(RocmGraphAnnotation_t hip_graph_annotation_id) { + hip_graph_.CaptureBegin(hip_graph_annotation_id); } -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(int) const { - return is_graph_captured_; +void ROCMExecutionProvider::PerThreadContext::CaptureEnd(RocmGraphAnnotation_t 
hip_graph_annotation_id) { + hip_graph_.CaptureEnd(hip_graph_annotation_id); } -Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(int graph_annotation_id) { - ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(RocmGraphAnnotation_t graph_annotation_id) const { + return hip_graph_.IsGraphCaptured(graph_annotation_id); +} +Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(RocmGraphAnnotation_t graph_annotation_id) { return hip_graph_.Replay(graph_annotation_id); } -void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { - ++regular_run_count_before_graph_capture_; +void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture( + RocmGraphAnnotation_t hip_graph_annotation_id) { + if (graph_id_to_run_count_.find(hip_graph_annotation_id) == graph_id_to_run_count_.end()) { + graph_id_to_run_count_[hip_graph_annotation_id] = 1; + return; + } + graph_id_to_run_count_[hip_graph_annotation_id]++; } void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { @@ -237,8 +257,7 @@ void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { } ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, - info.device_id)}, + : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { HIP_CALL_THROW(hipSetDevice(info_.device_id)); @@ -282,7 +301,7 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in ROCMExecutionProvider::~ROCMExecutionProvider() { // clean up thread local context caches { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { const auto cache = cache_weak.lock(); if (!cache) continue; @@ -317,13 +336,12 @@ ROCMExecutionProvider::PerThreadContext& ROCMExecutionProvider::GetPerThreadCont // get context and update cache std::shared_ptr context; { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); // get or create a context if (context_state_.retired_context_pool.empty()) { context = std::make_shared(info_.device_id, stream_, info_.gpu_mem_limit, - info_.arena_extend_strategy, info_.external_allocator_info, - info_.default_memory_arena_cfg); + info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -351,7 +369,7 @@ void ROCMExecutionProvider::ReleasePerThreadContext() const { ORT_ENFORCE(cached_context); { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); context_state_.active_contexts.erase(cached_context); context_state_.retired_context_pool.push_back(cached_context); } @@ -364,26 +382,28 @@ Status ROCMExecutionProvider::Sync() const { return Status::OK(); } -Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && 
GetPerThreadContext().IsGraphCaptureAllowed() && - !GetPerThreadContext().IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; - GetPerThreadContext().CaptureBegin(0); + RocmGraphAnnotation_t hip_graph_annotation_id = GetPerThreadContext().GetRocmGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(hip_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(hip_graph_annotation_id)) { + LOGS(*GetLogger(), INFO) << "Capturing the hip graph for this model"; + GetPerThreadContext().CaptureBegin(hip_graph_annotation_id); } return Status::OK(); } -Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(0)) { - if (GetPerThreadContext().IsGraphCaptureAllowed()) { - GetPerThreadContext().CaptureEnd(0); +Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + RocmGraphAnnotation_t hip_graph_annotation_id = GetPerThreadContext().GetRocmGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(hip_graph_annotation_id)) { + if (GetPerThreadContext().IsGraphCaptureAllowed(hip_graph_annotation_id)) { + GetPerThreadContext().CaptureEnd(hip_graph_annotation_id); // HIP work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(0)); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(hip_graph_annotation_id)); } else { - GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(hip_graph_annotation_id); } } @@ -412,18 +432,19 @@ bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { return info_.enable_hip_graph; } -bool ROCMExecutionProvider::IsGraphCaptured(int) const { - return GetPerThreadContext().IsGraphCaptured(0); +bool ROCMExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return GetPerThreadContext().IsGraphCaptured(graph_annotation_id); } -Status ROCMExecutionProvider::ReplayGraph(int /*graph_annotation_id*/) { - return GetPerThreadContext().ReplayGraph(0); +Status ROCMExecutionProvider::ReplayGraph(int graph_annotation_id) { + return GetPerThreadContext().ReplayGraph(graph_annotation_id); } namespace rocm { // opset 1 to 9 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, Cos); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, Cos); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, Cos); @@ -482,8 +503,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, 
MLFloat16, - LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow); @@ -516,32 +536,20 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, - LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, - LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, - LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, - LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, 
kOnnxDomain, 12, 15, uint32_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, - LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Add); @@ -597,8 +605,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 10, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Reciprocal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, - Reciprocal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Sqrt); @@ -612,18 +619,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Erf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, bool, Not); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, - BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, double, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, LRN); @@ -631,14 +632,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, - ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, - ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ConvTranspose); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, - AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, AveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalAveragePool); @@ -651,51 +649,54 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 
10, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, double, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, int32_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, int64_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSumExp); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); @@ -720,6 +721,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, float, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, double, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad); @@ -768,7 +770,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less); @@ -832,12 +833,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -851,45 +846,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Range); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, int32_t, 
ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); @@ -958,7 +914,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom // OpSet 12 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool); @@ -967,26 +922,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMin); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin); // OpSet 13 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow); @@ -1037,6 +982,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor); @@ -1049,6 +995,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Log); @@ -1106,7 +1053,6 @@ class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, bool, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); @@ -1126,6 +1072,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, U class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Concat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); @@ -1141,50 +1088,36 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSumExp); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Resize); @@ -1236,6 +1169,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin); + // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Relu); @@ -1280,16 +1220,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kRocmExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kRocmExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kRocmExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1313,6 +1256,7 @@ class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, BFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, int32_t, Where); @@ -1334,6 +1278,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, GridSample); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1342,18 +1287,24 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization); // Opset 18 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, Resize); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -1369,52 +1320,81 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint32_t, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint64_t, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, bool, Cast); - -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - float, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - float, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - MLFloat16, DequantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast); +// #endif + +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, DequantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, DequantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, DequantizeLinear); +// #endif +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, DequantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, DequantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, DequantizeLinear); +// #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, If); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Loop); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - float, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - float, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - MLFloat16, QuantizeLinear); -class 
ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, QuantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, QuantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, QuantizeLinear); +// #endif +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, QuantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, QuantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear); +// #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape); // Opset 20 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN); -// Opset 21 -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, - QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, - QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, - QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, - QuantizeLinear); +// Opset 21. 
+// TODO(fajin): support other quantized types +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, DequantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, DequantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, DequantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, DequantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, DequantizeLinear); +// #endif + +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, QuantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, QuantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, QuantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear); +// #endif template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1427,6 +1407,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, 
BuildKernelCreateInfo, @@ -1632,51 +1613,55 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1814,15 +1799,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + 19, IsInf)>, // opset 11 - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1836,45 +1815,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1948,27 +1888,18 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 13 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2011,7 +1942,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2019,6 +1949,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2031,6 +1962,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2086,6 +2018,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2120,62 +2053,43 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2219,6 +2133,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 14 BuildKernelCreateInfo, @@ -2264,16 +2184,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2297,6 +2213,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2326,23 +2243,30 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 18 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, @@ -2358,11 +2282,23 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, 
BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2370,26 +2306,58 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 21 + // TODO(fajin): support other quantized types BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif }; for (auto& function_table_entry : function_table) { @@ -2446,6 +2414,26 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { return false; } +static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) { + // Opset 12 introduced the attribute "select_last_index" + if (node.SinceVersion() >= 12) { + const auto& node_attributes = node.GetAttributes(); + + for (auto& attr : node_attributes) { + auto& attr_name = attr.first; + auto& attr_value = attr.second; + + // It is not supported to pick the last index in case of encountering duplicate max values. + if ("select_last_index" == attr_name) { + if (attr_value.i() != 0) { + return true; + } + } + } + } + + return false; +} std::unique_ptr ROCMExecutionProvider::GetDataTransfer() const { return std::make_unique(); } @@ -2454,6 +2442,9 @@ std::vector> ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup) const { InlinedVector candidates; + // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. 
+ InlinedVector tentative_nodes; + const logging::Logger& logger = *GetLogger(); for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) @@ -2461,13 +2452,16 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const auto& node = *p_node; if (!node.GetExecutionProviderType().empty()) { + if (node.GetExecutionProviderType() == kRocmExecutionProvider) { + candidates.push_back(node.Index()); + } continue; } const KernelCreateInfo* rocm_kernel_def = kernel_lookup.LookUpKernel(node); // none of the provided registries has a ROCM kernel for this node if (rocm_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "ROCM kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, INFO) << "ROCM kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); continue; } @@ -2478,6 +2472,9 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, "GRU" == node.OpType()) { not_supported = true; force_inside = !not_supported; + } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) { + not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node); + force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside @@ -2485,9 +2482,10 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, if (!force_inside && not_supported) { if (not_supported) { - LOGS_DEFAULT(WARNING) << "ROCM kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, WARNING) << "ROCM kernel not supported. 
Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); } } else { + tentative_nodes.push_back(node.Index()); candidates.push_back(node.Index()); } } @@ -2495,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For ROCM EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) @@ -2519,7 +2517,8 @@ void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_, use_ep_level_unified_stream_, GetPerThreadContext().MiopenHandle(), - GetPerThreadContext().RocblasHandle()); + GetPerThreadContext().HipblasHandle(), + info_); } OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 6d6c05027e7bd..be467869248ea 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -8,7 +8,7 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/rocm/rocm_execution_provider_info.h" #include "core/providers/rocm/rocm_graph.h" #include "core/providers/rocm/rocm_pch.h" @@ -37,14 +37,20 @@ class ROCMExecutionProvider : public IExecutionProvider { return nullptr; } - rocblas_handle PerThreadDefaultRocblasHandle() { - return GetPerThreadContext().RocblasHandle(); + hipblasHandle_t PerThreadDefaultHipblasHandle() { + return GetPerThreadContext().HipblasHandle(); } miopenHandle_t PerThreadDefaultMiopenHandle() { return GetPerThreadContext().MiopenHandle(); } + hipStream_t ComputeStream() { + // this will return the ROCM EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return stream_; + } + template const T* GetConstOnes(size_t count, hipStream_t stream) { return GetPerThreadContext().template GetConstOnes(count, stream); @@ -75,8 +81,8 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured(int graph_annotation_id) const override; - Status ReplayGraph(int graph_annotation_id) override; + bool IsGraphCaptured(RocmGraphAnnotation_t graph_annotation_id) const override; + Status ReplayGraph(RocmGraphAnnotation_t graph_annotation_id) override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -98,9 +104,10 @@ class ROCMExecutionProvider : public IExecutionProvider { PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy, ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); - rocblas_handle 
RocblasHandle() const { - return rocblas_handle_; + hipblasHandle_t HipblasHandle() const { + return hipblas_handle_; } miopenHandle_t MiopenHandle() const { @@ -138,15 +145,17 @@ class ROCMExecutionProvider : public IExecutionProvider { } } - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - bool IsGraphCaptured(int graph_annotation_id) const; - Status ReplayGraph(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); + bool IsGraphCaptureAllowed(RocmGraphAnnotation_t hip_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(RocmGraphAnnotation_t hip_graph_annotation_id) const; + void CaptureBegin(RocmGraphAnnotation_t hip_graph_annotation_id); + void CaptureEnd(RocmGraphAnnotation_t hip_graph_annotation_id); + bool IsGraphCaptured(RocmGraphAnnotation_t hip_graph_annotation_id) const; + RocmGraphAnnotation_t GetRocmGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + Status ReplayGraph(RocmGraphAnnotation_t hip_graph_annotation_id); + void IncrementRegularRunCountBeforeGraphCapture(RocmGraphAnnotation_t hip_graph_annotation_id); private: - rocblas_handle rocblas_handle_ = nullptr; + hipblasHandle_t hipblas_handle_ = nullptr; miopenHandle_t miopen_handle_ = nullptr; std::unique_ptr> constant_ones_float_; @@ -157,8 +166,8 @@ class ROCMExecutionProvider : public IExecutionProvider { // Hip graph with multi threads will be supported in the future, so hip_graph_ // is put under PerThreadContext. ROCMGraph hip_graph_; - bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = 0; + // Map of graph id to regular_run_count_before_graph_capture + std::unordered_map graph_id_to_run_count_; // There is chance that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. @@ -196,7 +205,7 @@ class ROCMExecutionProvider : public IExecutionProvider { std::set, std::owner_less>> caches_to_update_on_destruction; // synchronizes access to PerThreadContextState members - OrtMutex mutex; + std::mutex mutex; }; // The execution provider maintains the PerThreadContexts in this structure. 
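Note on the ArgMax/ArgMin fallback introduced above: opset 12 added the `select_last_index` attribute, and the ROCm kernels only implement the default tie-breaking (first index wins), so any non-zero value is routed to the CPU EP by ArgMaxOrArgMinNeedFallbackToCPU. A minimal sketch of why the two settings disagree on duplicated maxima follows; the `ArgMax1D` helper is hypothetical and only illustrates the semantics, it is not part of the change.

// Illustration only: with duplicate maxima the two attribute settings return
// different indices, which is what ArgMaxOrArgMinNeedFallbackToCPU guards against.
#include <cstddef>
#include <vector>

static size_t ArgMax1D(const std::vector<float>& v, bool select_last_index) {
  size_t best = 0;
  for (size_t i = 1; i < v.size(); ++i) {
    const bool better = select_last_index ? (v[i] >= v[best]) : (v[i] > v[best]);
    if (better) best = i;
  }
  return best;
}

// ArgMax1D({3.f, 7.f, 7.f}, /*select_last_index=*/false) == 1  -> handled by the ROCm kernel
// ArgMax1D({3.f, 7.f, 7.f}, /*select_last_index=*/true)  == 2  -> falls back to the CPU EP
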
diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 70bf08d65401a..933a72122e7f9 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -89,22 +89,22 @@ class RocmKernel : public OpKernel { return stream->miopen_handle_; } - inline rocblas_handle GetRocblasHandle(OpKernelContext* ctx) const { - return GetRocblasHandle(static_cast(ctx->GetComputeStream())); + inline hipblasHandle_t GetHipblasHandle(OpKernelContext* ctx) const { + return GetHipblasHandle(static_cast(ctx->GetComputeStream())); } - static inline rocblas_handle GetRocblasHandle(onnxruntime::RocmStream* stream) { - return stream->rocblas_handle_; - } - - tunable::RocmTuningContext* GetTuningContext() const { - return static_cast(provider_->GetTuningContext()); + static inline hipblasHandle_t GetHipblasHandle(onnxruntime::RocmStream* stream) { + return stream->hipblas_handle_; } bool UseTF32() const { return false; } + tunable::RocmTuningContext* GetTuningContext() const { + return static_cast(provider_->GetTuningContext()); + } + // To support hipMemcpyAsync, the cpu memory should be allocated in pinned memory // and it can only be released after the copy has finished template @@ -169,14 +169,20 @@ class RocmKernel : public OpKernel { const RocmKernel* op_kernel_; }; - inline rocblas_handle DefaultRocblasHandle() const { - return provider_->PerThreadDefaultRocblasHandle(); + inline hipblasHandle_t DefaultHipblasHandle() const { + return provider_->PerThreadDefaultHipblasHandle(); } inline miopenHandle_t DefaultMiopenHandle() const { return provider_->PerThreadDefaultMiopenHandle(); } + inline hipStream_t DefaultHipStream() const { + // this will return the ROCM EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return provider_->ComputeStream(); + } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); return gpu_data_transfer->CopyTensorAsync(src, dst, stream); diff --git a/onnxruntime/core/providers/rocm/rocm_pch.h b/onnxruntime/core/providers/rocm/rocm_pch.h index 723b990c8d290..9713e41e126bb 100644 --- a/onnxruntime/core/providers/rocm/rocm_pch.h +++ b/onnxruntime/core/providers/rocm/rocm_pch.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #ifdef ORT_USE_NCCL diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index fdf64d07e0a6c..170a566d850b0 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -185,7 +185,7 @@ struct ROCM_Provider : Provider { info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; - info.enable_hip_graph = params->enable_hip_graph; + info.enable_hip_graph = params->enable_hip_graph != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc index 
ef5689fc9a2d0..bbd1e1befccee 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc @@ -7,6 +7,25 @@ namespace onnxruntime { +DeferredCpuAllocator::DeferredCpuAllocator(RocmStream& rocm_stream) : rocm_stream_(rocm_stream) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = + [](OrtAllocator* this_, size_t size) { + auto self = reinterpret_cast(this_); + return self->rocm_stream_.GetCpuAllocator()->Alloc(size); + }; + OrtAllocator::Free = + [](OrtAllocator* this_, void* p) { + auto self = reinterpret_cast(this_); + self->rocm_stream_.EnqueDeferredCPUBuffer(p); + }; + OrtAllocator::Info = + [](const OrtAllocator* this_) { + auto self = reinterpret_cast(this_); + return &self->rocm_stream_.GetCpuAllocator()->Info(); + }; +} + struct RocmNotification : public synchronize::Notification { RocmNotification(Stream& s) : Notification(s) { HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); @@ -25,7 +44,8 @@ struct RocmNotification : public synchronize::Notification { void wait_on_device(Stream& device_stream) { ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString()); // launch a wait command to the rocm stream - HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); + HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), + event_, 0)); }; void wait_on_host() { @@ -42,18 +62,21 @@ RocmStream::RocmStream(hipStream_t stream, bool release_cpu_buffer_on_rocm_stream, bool own_flag, miopenHandle_t external_miopen_handle, - rocblas_handle external_rocblas_handle) : Stream(stream, device), - own_stream_(own_flag), - cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream) { + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info) : Stream(stream, device), + own_stream_(own_flag), + cpu_allocator_(cpu_allocator), + release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream), + deferred_cpu_allocator_(*this), + ep_info_(ep_info) { if (own_flag) { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - ROCBLAS_CALL_THROW(rocblas_set_stream(rocblas_handle_, stream)); + HIPBLAS_CALL_THROW(hipblasCreate(&hipblas_handle_)); + HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); } else { - rocblas_handle_ = external_rocblas_handle; - ROCBLAS_CALL_THROW(rocblas_set_stream(rocblas_handle_, stream)); + hipblas_handle_ = external_hipblas_handle; + HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); miopen_handle_ = external_miopen_handle; MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); } @@ -62,7 +85,7 @@ RocmStream::RocmStream(hipStream_t stream, RocmStream::~RocmStream() { ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); if (own_stream_) { - rocblas_destroy_handle(rocblas_handle_); + hipblasDestroy(hipblas_handle_); miopenDestroy(miopen_handle_); auto* handle = GetHandle(); if (handle) @@ -149,8 +172,18 @@ void* RocmStream::GetResource(int version, int id) const { case RocmResource::miopen_handle_t: return reinterpret_cast(miopen_handle_); break; - case RocmResource::rocblas_handle_t: - return reinterpret_cast(rocblas_handle_); + case RocmResource::hipblas_handle_t: + return reinterpret_cast(hipblas_handle_); + break; + case 
RocmResource::deferred_cpu_allocator_t: + return const_cast(&deferred_cpu_allocator_); + break; + case RocmResource::device_id_t: + return reinterpret_cast(ep_info_.device_id); + break; + case RocmResource::arena_extend_strategy_t: + return reinterpret_cast(ep_info_.arena_extend_strategy); + break; break; default: break; @@ -174,25 +207,28 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis hipStream_t external_stream, bool use_existing_stream, miopenHandle_t external_miopen_handle, - rocblas_handle external_rocblas_handle) { + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info) { // wait rocm notification on rocm ep stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitRocmNotificationOnDevice); // wait rocm notification on cpu ep stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitRocmNotificationOnHost); if (!use_existing_stream) - stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream](const OrtDevice& device) { + stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream, ep_info](const OrtDevice& device) { HIP_CALL_THROW(hipSetDevice(device.Id())); hipStream_t stream = nullptr; HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr); + // HIP_CALL_THROW(hipStreamCreate(&stream)); + return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr, ep_info); }); else stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream, external_stream, external_miopen_handle, - external_rocblas_handle](const OrtDevice& device) { - return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, false, external_miopen_handle, external_rocblas_handle); + external_hipblas_handle, + ep_info](const OrtDevice& device) { + return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, false, external_miopen_handle, external_hipblas_handle, ep_info); }); } diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h index 30983ce03568f..320fb4661e987 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h @@ -3,13 +3,21 @@ #pragma once #include "core/providers/rocm/rocm_pch.h" -// #include "core/providers/cuda/shared_inc/cuda_utils.h" +// #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" #include "core/framework/stream_handles.h" +#include "core/providers/rocm/rocm_execution_provider_info.h" namespace onnxruntime { + +struct RocmStream; void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification); +struct DeferredCpuAllocator : public OrtAllocator { + DeferredCpuAllocator(RocmStream&); + RocmStream& rocm_stream_; +}; + struct RocmStream : Stream { RocmStream(hipStream_t stream, const OrtDevice& device, @@ -17,7 +25,8 @@ struct RocmStream : Stream { bool release_cpu_buffer_on_rocm_stream, bool own_flag, miopenHandle_t external_miopen_handle, - rocblas_handle external_rocblas_handle); + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info); ~RocmStream(); @@ -33,16 
+42,20 @@ struct RocmStream : Stream { miopenHandle_t miopen_handle_{}; - rocblas_handle rocblas_handle_{}; + hipblasHandle_t hipblas_handle_{}; void* GetResource(int version, int id) const override; + onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; } private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; bool release_cpu_buffer_on_rocm_stream_{true}; + DeferredCpuAllocator deferred_cpu_allocator_; + const ROCMExecutionProviderInfo ep_info_; }; void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, @@ -52,5 +65,6 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis hipStream_t external_stream, bool use_existing_stream, miopenHandle_t external_miopen_handle, - rocblas_handle external_rocblas_handle); + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h index d93f70785c093..675b30612065b 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h @@ -12,6 +12,41 @@ #else #define FLAG 0 #endif +// needed to work around calling rocblas API instead of hipblas API +static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) { + switch (op) { + case HIPBLAS_OP_N: + return rocblas_operation_none; + case HIPBLAS_OP_T: + return rocblas_operation_transpose; + case HIPBLAS_OP_C: + return rocblas_operation_conjugate_transpose; + } + assert(0 && "HIPBLAS_STATUS_INVALID_ENUM"); +} +static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) { + switch (error) { + case rocblas_status_size_unchanged: + case rocblas_status_size_increased: + case rocblas_status_success: + return HIPBLAS_STATUS_SUCCESS; + case rocblas_status_invalid_handle: + return HIPBLAS_STATUS_NOT_INITIALIZED; + case rocblas_status_not_implemented: + return HIPBLAS_STATUS_NOT_SUPPORTED; + case rocblas_status_invalid_pointer: + case rocblas_status_invalid_size: + case rocblas_status_invalid_value: + return HIPBLAS_STATUS_INVALID_VALUE; + case rocblas_status_memory_error: + return HIPBLAS_STATUS_ALLOC_FAILED; + case rocblas_status_internal_error: + return HIPBLAS_STATUS_INTERNAL_ERROR; + default: + assert(0 && "ROCBLAS_STATUS_INVALID_ENUM"); + return HIPBLAS_STATUS_INTERNAL_ERROR; + } +} using namespace onnxruntime; @@ -22,6 +57,465 @@ inline int get_flag() { // Generalize library calls to be use in template functions +// hipblas + +// gemm +inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const float* A, int lda, + const float* B, int ldb, + const float* beta, + float* C, int ldc) { + return hipblasGemmEx(handle, + transa, + transb, + m, n, k, + alpha, + A, HIP_R_32F, lda, + B, HIP_R_32F, ldb, + beta, + C, HIP_R_32F, ldc, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); +} + +inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const double* alpha, + const double* A, int lda, + const double* B, int ldb, + const double* beta, + double* C, int ldc) { + return hipblasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + 
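The float and double overloads above establish the pattern this header follows for the hipBLAS migration: one hipblasGemmHelper overload set hides whether a call routes through hipblasGemmEx with explicit storage and compute types or through a typed entry point such as hipblasDgemm. A minimal sketch of the fp32 routing, assuming the ROCm 6 hipBLAS API already used in this diff (hipDataType plus hipblasComputeType_t); the GemmF32 name is illustrative, not ORT code:

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>

// Mirrors hipblasGemmHelper(float) above: explicit storage type (HIP_R_32F)
// and compute type (HIPBLAS_COMPUTE_32F) so every precision keeps one call shape.
inline hipblasStatus_t GemmF32(hipblasHandle_t handle,
                               hipblasOperation_t ta, hipblasOperation_t tb,
                               int m, int n, int k,
                               const float* alpha, const float* A, int lda,
                               const float* B, int ldb,
                               const float* beta, float* C, int ldc) {
  return hipblasGemmEx(handle, ta, tb, m, n, k,
                       alpha, A, HIP_R_32F, lda,
                       B, HIP_R_32F, ldb,
                       beta, C, HIP_R_32F, ldc,
                       HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT);
}

The half and bfloat16 overloads that follow keep this exact signature, which is what lets kernel code switch element types without touching call sites.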
+inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const half* alpha, + const half* A, int lda, + const half* B, int ldb, + const half* beta, + half* C, int ldc) { + float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); + float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); + return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + &h_a, + A, rocblas_datatype_f16_r, lda, + B, rocblas_datatype_f16_r, ldb, + &h_b, + C, rocblas_datatype_f16_r, ldc, + C, rocblas_datatype_f16_r, ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); +} + +inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const half* A, int lda, + const half* B, int ldb, + const float* beta, + half* C, int ldc) { + return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + alpha, + A, rocblas_datatype_f16_r, lda, + B, rocblas_datatype_f16_r, ldb, + beta, + C, rocblas_datatype_f16_r, ldc, + C, rocblas_datatype_f16_r, ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); +} + +inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const half* A, int lda, + const half* B, int ldb, + const float* beta, + half* C, int ldc, + const hipDeviceProp_t&, + bool /*use_tf32*/) { + return hipblasGemmHelper(handle, + transa, + transb, + m, n, k, + alpha, + A, lda, + B, ldb, + beta, + C, ldc); +} + +inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* A, int lda, + const BFloat16* B, int ldb, + const BFloat16* beta, + BFloat16* C, int ldc) { + float h_a = alpha->ToFloat(); + float h_b = beta->ToFloat(); + + // accumulating in FP32 + return hipblasGemmEx(handle, + transa, + transb, + m, n, k, + &h_a, + A, HIP_R_16BF, lda, + B, HIP_R_16BF, ldb, + &h_b, + C, HIP_R_16BF, ldc, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); +} + +// Compatible for function call with extra arguments (see cublasGemmHelper) +template +hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const Scalar* alpha, + const Scalar* A, int lda, + const Scalar* B, int ldb, + const Scalar* beta, + Scalar* C, int ldc, + const hipDeviceProp_t&, + bool /*use_tf32*/) { + return hipblasGemmHelper(handle, + transa, + transb, + m, n, k, + alpha, + A, lda, + B, ldb, + beta, + C, ldc); +} + +// batched gemm +inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const float* Aarray[], int lda, + const float* Barray[], int ldb, + const float* beta, + float* Carray[], int ldc, + int batchCount) { + return hipblasGemmBatchedEx(handle, + transa, + transb, + m, n, k, + alpha, + (const void**)Aarray, HIP_R_32F, lda, + (const void**)Barray, HIP_R_32F, ldb, + beta, + (void**)Carray, HIP_R_32F, ldc, + batchCount, + HIPBLAS_COMPUTE_32F, + 
HIPBLAS_GEMM_DEFAULT); +} +inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const double* alpha, + const double* Aarray[], int lda, + const double* Barray[], int ldb, + const double* beta, + double* Carray[], int ldc, + int batchCount) { + return hipblasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); +} +inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const half* alpha, + const half* Aarray[], int lda, + const half* Barray[], int ldb, + const half* beta, + half* Carray[], int ldc, + int batchCount) { + float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); + float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); + return rocBLASStatusToHIPStatus(rocblas_gemm_batched_ex((rocblas_handle)handle, + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + &h_a, + (const void**)Aarray, rocblas_datatype_f16_r, lda, + (const void**)Barray, rocblas_datatype_f16_r, ldb, + &h_b, + (void**)Carray, rocblas_datatype_f16_r, ldc, + (void**)Carray, rocblas_datatype_f16_r, ldc, + batchCount, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); +} + +inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* Aarray[], int lda, + const BFloat16* Barray[], int ldb, + const BFloat16* beta, + BFloat16* Carray[], int ldc, + int batch_count) { + float h_a = alpha->ToFloat(); + float h_b = beta->ToFloat(); + + // accumulating in FP32 + return hipblasGemmBatchedEx(handle, + transa, + transb, + m, n, k, + &h_a, + (const void**)Aarray, HIP_R_16BF, lda, + (const void**)Barray, HIP_R_16BF, ldb, + &h_b, + (void**)Carray, HIP_R_16BF, ldc, + batch_count, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); +} + +// strided batched gemm +inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const float* A, int lda, + long long int strideA, + const float* B, int ldb, + long long int strideB, + const float* beta, + float* C, int ldc, + long long int strideC, + int batchCount) { + return hipblasGemmStridedBatchedEx(handle, + transa, + transb, + m, n, k, + alpha, + A, HIP_R_32F, lda, strideA, + B, HIP_R_32F, ldb, strideB, + beta, + C, HIP_R_32F, ldc, strideC, + batchCount, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); +} + +inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const double* alpha, + const double* A, int lda, + long long int strideA, + const double* B, int ldb, + long long int strideB, + const double* beta, + double* C, int ldc, + long long int strideC, + int batchCount) { + return hipblasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const __half* alpha, + const __half* A, int lda, + long long int strideA, + const __half* B, int ldb, + long long int 
strideB, + const __half* beta, + __half* C, int ldc, + long long int strideC, + int batchCount) { + float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); + float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); + return rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + &h_a, + A, rocblas_datatype_f16_r, lda, strideA, + B, rocblas_datatype_f16_r, ldb, strideB, + &h_b, + C, rocblas_datatype_f16_r, ldc, strideC, + C, rocblas_datatype_f16_r, ldc, strideC, + batchCount, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); +} + +inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const __half* A, int lda, + intmax_t strideA, + const __half* B, int ldb, + intmax_t strideB, + const float* beta, + __half* C, int ldc, + intmax_t strideC, + int batchCount) { + return rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + alpha, + A, rocblas_datatype_f16_r, lda, strideA, + B, rocblas_datatype_f16_r, ldb, strideB, + beta, + C, rocblas_datatype_f16_r, ldc, strideC, + C, rocblas_datatype_f16_r, ldc, strideC, + batchCount, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); +} + +inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* A, int lda, + intmax_t strideA, + const BFloat16* B, int ldb, + intmax_t strideB, + const BFloat16* beta, + BFloat16* C, int ldc, + intmax_t strideC, + int batch_count) { + float h_a = alpha->ToFloat(); + float h_b = beta->ToFloat(); + // accumulating in FP32 + return hipblasGemmStridedBatchedEx(handle, + transa, + transb, + m, n, k, + &h_a, + A, HIP_R_16BF, lda, strideA, + B, HIP_R_16BF, ldb, strideB, + &h_b, + C, HIP_R_16BF, ldc, strideC, + batch_count, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); +} + +// Compatible for function call with with extra arguments (see cublasGemmStridedBatchedHelper) +template +hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const Scalar* alpha, + const Scalar* A, int lda, + intmax_t strideA, + const Scalar* B, int ldb, + intmax_t strideB, + const Scalar* beta, + Scalar* C, int ldc, + intmax_t strideC, + int batchCount, + const hipDeviceProp_t&, + bool /*use_tf32*/) { + return hipblasGemmStridedBatchedHelper(handle, + transa, + transb, + m, n, k, + alpha, + A, lda, strideA, + B, ldb, strideB, + beta, + C, ldc, strideC, + batchCount); +} + +inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const __half* A, int lda, + intmax_t strideA, + const __half* B, int ldb, + intmax_t strideB, + const float* beta, + __half* C, int ldc, + intmax_t strideC, + int batchCount, + const hipDeviceProp_t&, + bool /*use_tf32*/) { + return hipblasGemmStridedBatchedHelper(handle, + transa, + transb, + m, n, k, + alpha, + A, lda, strideA, + B, ldb, strideB, + beta, + C, ldc, strideC, + batchCount); +} + +// transpose using geam +inline hipblasStatus_t 
hipblasTransposeHelper(hipStream_t /*stream*/, hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { + return hipblasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} +inline hipblasStatus_t hipblasTransposeHelper(hipStream_t /*stream*/, hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, double* C, int ldc) { + return hipblasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +inline bool CanUse_hipblasTransposeHelper_MLFloat16(int /*m*/, int /*n*/) { return true; } // CUDA has a limited grid size of 65536, ROCm has higher limits. +hipblasStatus_t hipblasTransposeHelper(hipStream_t stream, hipblasHandle_t, hipblasOperation_t, hipblasOperation_t, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); + +// copy +inline hipblasStatus_t hipblasCopyHelper(hipStream_t /*stream*/, hipblasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { + return hipblasScopy(handle, n, x, incx, y, incy); +} +inline hipblasStatus_t hipblasCopyHelper(hipStream_t /*stream*/, hipblasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { + return hipblasDcopy(handle, n, x, incx, y, incy); +} +hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t handle, int n, const half* x, int incx, half* y, int incy); +hipblasStatus_t hipblasCopyHelper(hipStream_t stream, hipblasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); + +// rocblas + // gemm inline rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_operation transa, @@ -461,24 +955,3 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, C, ldc, strideC, batchCount); } - -// transpose using geam -inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { - return rocblas_sgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); -} -inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, double* C, int ldc) { - return rocblas_dgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); -} - -inline bool CanUse_rocblasTransposeHelper_MLFloat16(int /*m*/, int /*n*/) { return true; } // CUDA has a limited grid size of 65536, ROCm has higher limits. 
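With the hipblas-named helpers above in place, operator code no longer calls rocblas directly: the half and BFloat16 overloads fall back to rocblas_gemm_ex internally via the handle cast and the hipOperationToRocOperation / rocBLASStatusToHIPStatus mappings at the top of this header, while callers see one uniform interface. A caller-side sketch, assuming this fpgeneric.h is on the include path and dA/dB/dC are device buffers holding column-major data (illustrative only, error handling elided):

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include "core/providers/rocm/shared_inc/fpgeneric.h"  // the header modified here

void ExampleGemm(const float* dA, const float* dB, float* dC,
                 int m, int n, int k) {
  hipblasHandle_t handle = nullptr;
  if (hipblasCreate(&handle) != HIPBLAS_STATUS_SUCCESS) return;

  const float alpha = 1.0f, beta = 0.0f;
  // Same call shape for float/double/half/BFloat16; the low-precision
  // overloads route to rocblas_gemm_ex behind the hipblasHandle_t cast.
  hipblasGemmHelper(handle, HIPBLAS_OP_N, HIPBLAS_OP_N,
                    m, n, k,
                    &alpha, dA, /*lda=*/m,
                    dB, /*ldb=*/k,
                    &beta, dC, /*ldc=*/m);

  hipblasDestroy(handle);
}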
-rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); - -// copy -inline rocblas_status rocblasCopyHelper(hipStream_t /*stream*/, rocblas_handle handle, int n, const float* x, int incx, float* y, int incy) { - return rocblas_scopy(handle, n, x, incx, y, incy); -} -inline rocblas_status rocblasCopyHelper(hipStream_t /*stream*/, rocblas_handle handle, int n, const double* x, int incx, double* y, int incy) { - return rocblas_dcopy(handle, n, x, incx, y, incy); -} -rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const half* x, int incx, half* y, int incy); -rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h index 253ded1911cb5..563ae17fcdb3b 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h @@ -17,6 +17,7 @@ std::conditional_t RocmCall( #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) +#define HIPBLAS_CALL(expr) (RocmCall((expr), #expr, "HIPBLAS", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define ROCMSMI_CALL(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPSPARSE_CALL(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) @@ -28,6 +29,7 @@ std::conditional_t RocmCall( #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) +#define HIPBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPBLAS", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define ROCMSMI_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPSPARSE_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/rocm/tunable/gemm.cu b/onnxruntime/core/providers/rocm/tunable/gemm.cu index b4b7eb47bed2f..f40440e55be9b 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm.cu +++ b/onnxruntime/core/providers/rocm/tunable/gemm.cu @@ -35,7 +35,7 @@ inline GEMM(T, ScalarT) { GemmParams params; params.tuning_ctx = tuning_ctx; params.stream = stream; - params.handle = handle; + params.handle = (rocblas_handle)handle; params.opa = opa; params.opb = opb; @@ -75,7 +75,7 @@ inline BATCHED_GEMM(T, ScalarT) { BatchedGemmParams params; params.tuning_ctx = tuning_ctx; params.stream = stream; - params.handle = handle; + params.handle = (rocblas_handle)handle; params.opa = opa; params.opb = opb; @@ -116,7 +116,7 @@ inline STRIDED_BATCHED_GEMM(T, ScalarT) { StridedBatchedGemmParams params; params.tuning_ctx = tuning_ctx; params.stream = stream; - params.handle = handle; + params.handle = (rocblas_handle)handle; params.opa = opa; params.opb = opb; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm.h b/onnxruntime/core/providers/rocm/tunable/gemm.h index c124506f6f988..5b06535cb3862 100644 --- 
a/onnxruntime/core/providers/rocm/tunable/gemm.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm.h @@ -14,32 +14,32 @@ namespace blas { #define GEMM(T, ScalarT) \ common::Status Gemm( \ - RocmTuningContext* tuning_ctx, Stream* stream, rocblas_handle handle, \ + RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ BlasOp opa, BlasOp opb, \ std::int64_t m, std::int64_t n, std::int64_t k, \ ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ ScalarT beta, T* c, std::int64_t ldc) -#define BATCHED_GEMM(T, ScalarT) \ - common::Status BatchedGemm( \ - RocmTuningContext* tuning_ctx, Stream* stream, rocblas_handle handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, \ - const T** as, std::int64_t lda, \ - const T** bs, std::int64_t ldb, \ - ScalarT beta, \ +#define BATCHED_GEMM(T, ScalarT) \ + common::Status BatchedGemm( \ + RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ + BlasOp opa, BlasOp opb, \ + std::int64_t m, std::int64_t n, std::int64_t k, \ + ScalarT alpha, \ + const T** as, std::int64_t lda, \ + const T** bs, std::int64_t ldb, \ + ScalarT beta, \ T** cs, std::int64_t ldc, std::int64_t batch) -#define STRIDED_BATCHED_GEMM(T, ScalarT) \ - common::Status StridedBatchedGemm( \ - RocmTuningContext* tuning_ctx, Stream* stream, rocblas_handle handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, \ - const T* a, std::int64_t lda, std::int64_t stride_a, \ - const T* b, std::int64_t ldb, std::int64_t stride_b, \ - ScalarT beta, \ +#define STRIDED_BATCHED_GEMM(T, ScalarT) \ + common::Status StridedBatchedGemm( \ + RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ + BlasOp opa, BlasOp opb, \ + std::int64_t m, std::int64_t n, std::int64_t k, \ + ScalarT alpha, \ + const T* a, std::int64_t lda, std::int64_t stride_a, \ + const T* b, std::int64_t ldb, std::int64_t stride_b, \ + ScalarT beta, \ T* c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch) namespace row_major { diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index 6554ed977cef6..486ce5bfb731a 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -37,26 +37,26 @@ enum ActivationType { }; template -constexpr hipblasltDatatype_t HipBlasDataTypeFor(); +constexpr hipDataType HipBlasDataTypeFor(); template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_32F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_32F; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_16F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_16F; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_16B; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_16BF; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_64F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_64F; } template @@ -108,7 +108,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp hipblasOperation_t trans_a = MapBlasOpToHipBlasLt(); hipblasOperation_t trans_b = MapBlasOpToHipBlasLt(); - hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor(); + hipDataType in_out_datatype = HipBlasDataTypeFor(); std::vector 
heuristic_result; HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle, @@ -119,7 +119,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp in_out_datatype, in_out_datatype, in_out_datatype, - HIPBLASLT_COMPUTE_F32, + HIPBLAS_COMPUTE_32F, heuristic_result)); HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle)); @@ -161,7 +161,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda)); HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb)); HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); int batch = GetBatchCountFromParams(params); if (batch > 1) { diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h b/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h index 580f465c4926b..95fa4f37d7f68 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h @@ -4,7 +4,6 @@ #pragma once #include -#include #include "core/providers/rocm/rocm_common.h" // avoid provider_api.h ODR violation #include "core/framework/tunable.h" @@ -22,7 +21,6 @@ template using Op = Op; class Timer; - template using TunableOp = TunableOp; diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index 05cdc82e90564..88e5fde189ba2 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -42,26 +42,6 @@ static Status ValidateRocBlasVersion(const std::string& value) { return Status::OK(); } -std::string RocmTuningResultsValidator::GetDeviceModel() const { - return ep_->GetDeviceProp().name; -} - -Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { - auto current = GetDeviceModel(); - ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, - ", onnxruntime currently run with device ", current); - return Status::OK(); -} - -RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} { - RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); - RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); - RegisterValidator( - "DEVICE_MODEL", - [this]() { return GetDeviceModel(); }, - [this](const std::string& value) { return ValidateDeviceModel(value); }); -} - std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { std::ostringstream oss; #ifdef USE_COMPOSABLE_KERNEL @@ -87,6 +67,26 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { return oss.str(); } +std::string RocmTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + +RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : 
ep_{ep} { + RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); + RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) : ITuningContext(ep), info_(info), validator_(ep) {} diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index b84825236a453..45f81ed22b7f7 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -294,7 +294,8 @@ std::unique_ptr CreateGPUDataTransfer(); std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); std::string GetEnvironmentVar(const std::string& var_name); @@ -371,8 +372,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { namespace QDQ { inline std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer* graph_viewer) { - return g_host->QDQ__GetAllNodeUnits(graph_viewer); +GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { + return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger); } } // namespace QDQ diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d3b12f9728135..aa8c367d25d51 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -369,8 +369,9 @@ std::string GetEnvironmentVar(const std::string& var_name) { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { - return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) { + return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } namespace profiling { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 34319287a80fd..d182d0b9173bd 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -202,7 +202,8 @@ struct ProviderHost { virtual std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) = 0; + gsl::span tentative_nodes, + const logging::Logger& logger) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0; @@ -389,6 +390,7 @@ struct ProviderHost { virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; virtual void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, 
ONNX_NAMESPACE::AttributeProto_AttributeType value) = 0; virtual ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual std::string* AttributeProto__release_s(ONNX_NAMESPACE::AttributeProto* p) = 0; // GraphProto virtual std::unique_ptr GraphProto__construct() = 0; @@ -578,6 +580,8 @@ struct ProviderHost { // ConfigOptions virtual std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0; + virtual std::string ConfigOptions__GetConfigOrDefault(const ConfigOptions* p, const std::string& config_key, + const std::string& default_value) = 0; // OrtRunOptions virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0; @@ -888,7 +892,7 @@ struct ProviderHost { virtual std::unique_ptr NodeUnit__OutputEdgesEnd(const NodeUnit* p) = 0; virtual std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) = 0; + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0; // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, @@ -958,7 +962,7 @@ struct ProviderHost { // GraphViewer virtual void GraphViewer__operator_delete(GraphViewer* p) = 0; - virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger) = 0; + virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger, const ModelMetaData&) = 0; virtual const std::string& GraphViewer__Name(const GraphViewer* p) noexcept = 0; virtual const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0; @@ -994,6 +998,7 @@ struct ProviderHost { bool include_outer_scope_args, int execution_order) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; + virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0; // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 4644f703dcb5d..54249f0864cd7 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -122,6 +122,7 @@ struct AttributeProto final { void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); } void set_type(AttributeProto_AttributeType value) { return g_host->AttributeProto__set_type(this, value); } TensorProto* add_tensors() { return g_host->AttributeProto__add_tensors(this); } + std::string* release_s() { return g_host->AttributeProto__release_s(this); } typedef AttributeProto_AttributeType AttributeType; static constexpr AttributeType UNDEFINED = AttributeProto_AttributeType_UNDEFINED; @@ -485,6 +486,10 @@ struct ConfigOptions final { return g_host->ConfigOptions__GetConfigEntry(this, config_key); } + std::string GetConfigOrDefault(const std::string& config_key, const std::string& default_value) const { + return g_host->ConfigOptions__GetConfigOrDefault(this, config_key, default_value); + } + PROVIDER_DISALLOW_ALL(ConfigOptions) }; @@ -1018,11 +1023,13 @@ struct Graph final { PROVIDER_DISALLOW_ALL(Graph) }; +using ModelMetaData = std::unordered_map; + class GraphViewer final { public: static void operator delete(void* p) { 
g_host->GraphViewer__operator_delete(reinterpret_cast(p)); } - std::unique_ptr CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); } + std::unique_ptr CreateModel(const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) const { return g_host->GraphViewer__CreateModel(this, logger, metadata); } const std::string& Name() const noexcept { return g_host->GraphViewer__Name(this); } const std::filesystem::path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); } @@ -1064,6 +1071,7 @@ class GraphViewer final { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order); } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } + IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->GraphViewer__GetSchemaRegistry(this); } GraphViewer() = delete; GraphViewer(const GraphViewer&) = delete; diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index ef45d6c85d6a9..fbccd7d4a286b 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -128,7 +128,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, // Serialize modelproto to string auto new_graph_viewer = graph_build.CreateGraphViewer(); - auto model = new_graph_viewer->CreateModel(*logger); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); + auto model = new_graph_viewer->CreateModel(*logger, metadata); auto model_proto = model->ToProto(); new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c3d010ac9fcd7..1b432dad44263 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -452,9 +452,9 @@ TensorrtLogger& GetTensorrtLogger(bool verbose_log) { return trt_logger; } -std::unique_lock TensorrtExecutionProvider::GetApiLock() const { - static OrtMutex singleton; - return std::unique_lock(singleton); +std::unique_lock TensorrtExecutionProvider::GetApiLock() const { + static std::mutex singleton; + return std::unique_lock(singleton); } /* @@ -1236,7 +1236,7 @@ void TensorrtExecutionProvider::ReleasePerThreadContext() const { ORT_ENFORCE(cached_context); { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); context_state_.active_contexts.erase(cached_context); context_state_.retired_context_pool.push_back(cached_context); } @@ -1258,7 +1258,7 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh // get context and update cache std::shared_ptr context; { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); // get or create a context if (context_state_.retired_context_pool.empty()) { @@ -1725,6 +1725,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); } + trt_version_ = getInferLibVersion(); + 
CUDA_CALL_THROW(cudaRuntimeGetVersion(&cuda_version_)); + + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] CUDA version is " << cuda_version_; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: " << "device_id: " << device_id_ << ", trt_max_partition_iterations: " << max_partition_iterations_ @@ -1768,7 +1774,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv TensorrtExecutionProvider::~TensorrtExecutionProvider() { // clean up thread local context caches { - std::lock_guard lock(context_state_.mutex); + std::lock_guard lock(context_state_.mutex); for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { const auto cache = cache_weak.lock(); if (!cache) continue; @@ -1948,7 +1954,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph // Find inputs and outputs of the subgraph std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; int input_order = 0; int output_order = 0; @@ -2040,12 +2046,25 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); - // Sort inputs and outputs by the order they were added std::multimap inputs, outputs; + + // Get the input order of the original graph + int order = 0; + for (const auto* input : graph.GetInputs()) { + original_inputs[input] = order++; + } + + // input order needs to be consistent with original graph's input order for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { - inputs.insert(std::pair(it->second, it->first)); + const auto& iter = original_inputs.find(it->first); + if (iter != original_inputs.end()) { + inputs.insert(std::pair(iter->second, iter->first)); + } else { + inputs.insert(std::pair(it->second, it->first)); + } } + // Sort outputs by the order they were added for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { outputs.insert(std::pair(it->second, it->first)); } @@ -2449,23 +2468,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // So, simply return the ComputeCapability here. 
if (graph.NumberOfNodes() == 1 && GraphHasCtxNode(graph)) { SubGraph_t supported_node_vector = {{0}, true}; - std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); + std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)), 0); result.push_back(ComputeCapability::Create(std::move(sub_graph))); return result; } // Generate unique kernel name for TRT graph - HashValue model_hash = TRTGenerateId(graph); + HashValue model_hash = TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)); // Get supported node list from TensorRT parser const int number_of_ort_nodes = graph.NumberOfNodes(); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - std::vector filtered_nodes_vector; + std::set exclude_ops_set; + + /* + * There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. + * TRT EP automatically excludes DDS ops from running on TRT. + */ + if (trt_version_ >= 100000 && trt_version_ < 110000) { + exclude_ops_set.insert("NonMaxSuppression"); + exclude_ops_set.insert("NonZero"); + exclude_ops_set.insert("RoiAlign"); + LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable"; + } + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + bool new_subgraph = true; + + /* Iterate all the nodes and exclude the node if: + * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. + * 2. It's a DDS op. + */ for (const auto& index : nodes_vector) { const auto& node = graph.GetNode(node_index[index]); + bool supported_node = true; /* If current node is control flow op, we take different approach based on following four cases: * @@ -2477,29 +2516,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. */ if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { - auto sub_graphs = node->GetSubgraphs(); - if (sub_graphs.size() != 0) { - bool all_subgraphs_are_supported = true; - for (auto sub_graph : sub_graphs) { - // TRT EP should consider the empty subgraph is fully supported by TRT. - if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { - continue; - } - if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { - all_subgraphs_are_supported = false; - break; + auto supported_control_flow_op = [&](const Node* node) { + auto sub_graphs = node->GetSubgraphs(); + if (sub_graphs.size() != 0) { + for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. 
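The exclusion above keys off trt_version_, which the constructor fills from getInferLibVersion() using the (MAJOR * 100 + MINOR) * 100 + PATCH packing documented in the EP header, so DDS ops are filtered out only when a TensorRT 10.x runtime is linked. A standalone sketch of that gating with illustrative version values (plain C++, no TensorRT dependency):

#include <cstdint>
#include <cstdio>
#include <set>
#include <string>

// Packed as (MAJOR * 100 + MINOR) * 100 + PATCH, matching the header comment.
void PrintTrtVersion(int32_t v) {
  std::printf("TensorRT %d.%d.%d\n", v / 10000, (v / 100) % 100, v % 100);
}

// Mirrors the gating above: DDS ops are excluded only for a 10.x runtime.
std::set<std::string> BuildExcludeOps(int32_t trt_version) {
  if (trt_version >= 100000 && trt_version < 110000) {
    return {"NonMaxSuppression", "NonZero", "RoiAlign"};
  }
  return {};
}

int main() {
  for (int32_t v : {80601, 100300}) {  // e.g. 8.6.1 and 10.3.0, illustrative
    PrintTrtVersion(v);
    std::printf("  excluded op types: %zu\n", BuildExcludeOps(v).size());
  }
  return 0;
}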
+ if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } + if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { + // if not all its subgraphs are supported, we need to exclude this control flow op + return false; + } } } - if (!all_subgraphs_are_supported) { - // if not all its subgraphs are supported, we need to exclude this control flow op - continue; - } + return true; + }; + supported_node = supported_control_flow_op(node); + } + + // Exclude any ops, if applicable + if (exclude_ops_set.find(node->OpType()) != exclude_ops_set.end()) { + supported_node = false; + } + + if (supported_node) { + if (new_subgraph) { + parser_nodes_vector.emplace_back(); + // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser + parser_nodes_vector.back().second = false; + new_subgraph = false; } + parser_nodes_vector.back().first.emplace_back(index); + } else { + new_subgraph = true; } - filtered_nodes_vector.push_back(index); } - SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}}; bool early_termination = false; supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); if (early_termination) { @@ -2938,14 +2991,28 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Check platform availability for low precision if (fp16_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif if (!trt_builder->platformHasFastFp16()) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif fp16_enable_ = false; LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16"; } } if (int8_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif if (!trt_builder->platformHasFastInt8()) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif int8_enable_ = false; LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; } @@ -3161,12 +3228,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } } else { - // Set INT8 per tensor dynamic range - if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) #endif + // Set INT8 per tensor dynamic range + if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { trt_config->setInt8Calibrator(nullptr); #if defined(_MSC_VER) #pragma warning(pop) @@ -3416,7 +3483,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. 
// More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; @@ -3573,13 +3640,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView for (auto trt_profile : trt_profiles) { trt_config->addOptimizationProfile(trt_profile); } - - // Set INT8 Per Tensor Dynamic range - if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) #endif + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { trt_config->setInt8Calibrator(nullptr); #if defined(_MSC_VER) #pragma warning(pop) @@ -4086,7 +4152,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // The whole compute_function should be considered the critical section. // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 97c9367b0bb61..d3e0b0fba8891 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -12,7 +12,7 @@ typedef void* cudnnStatus_t; #endif #include "core/providers/tensorrt/nv_includes.h" -#include "core/platform/ort_mutex.h" +#include #include "core/providers/cuda/cuda_graph.h" #include "tensorrt_execution_provider_info.h" @@ -169,7 +169,7 @@ struct TensorrtFuncState { std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; - OrtMutex* tensorrt_mu_ptr = nullptr; + std::mutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; @@ -214,7 +214,7 @@ struct TensorrtShortFuncState { std::vector> output_info; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; - OrtMutex* tensorrt_mu_ptr = nullptr; + std::mutex* tensorrt_mu_ptr = nullptr; }; // Holds important information for building valid ORT graph. 
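The locking changes in this file and the header below swap OrtMutex for std::mutex without changing the discipline: the compute function still holds one lock for its whole body, and GetApiLock still hands back a scoped lock over a static mutex so non-thread-safe TensorRT API calls are serialized. A reduced, self-contained sketch of that pattern (names are illustrative, not the EP's members):

#include <mutex>

// Function-local static mutex plus a returned unique_lock: the lock stays
// held for the caller's scope and is released automatically on destruction.
std::unique_lock<std::mutex> AcquireApiLock() {
  static std::mutex api_mutex;
  return std::unique_lock<std::mutex>(api_mutex);
}

void BuildEngineSerialized() {
  auto lock = AcquireApiLock();
  // ... calls that must not run concurrently go here ...
}

int main() {
  BuildEngineSerialized();
  return 0;
}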
@@ -312,7 +312,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::string tactic_sources_; std::string global_cache_path_, cache_path_, engine_decryption_lib_path_; std::unique_ptr runtime_ = nullptr; - OrtMutex tensorrt_mu_; + std::mutex tensorrt_mu_; int device_id_; std::string compute_capability_; bool context_memory_sharing_enable_ = false; @@ -329,6 +329,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool cuda_graph_enable_ = false; std::string cache_prefix_; bool engine_hw_compatible_ = false; + std::string op_types_to_exclude_; + + // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH + int32_t trt_version_; + int32_t cuda_version_; // The OrtAllocator object will be get during ep compute time // and should be kept for the lifetime of TRT EP object. @@ -476,7 +481,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::set, std::owner_less>> caches_to_update_on_destruction; // synchronizes access to PerThreadContextState members - OrtMutex mutex; + std::mutex mutex; }; // The execution provider maintains the PerThreadContexts in this structure. @@ -509,7 +514,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading) should be protected by a lock when invoked by multiple threads concurrently. */ - std::unique_lock GetApiLock() const; + std::unique_lock GetApiLock() const; /**Check the graph is the subgraph of control flow op*/ bool IsSubGraphOfControlFlowOp(const GraphViewer& graph) const; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index a4d2d6c9d65f3..e93d3565fe33d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -28,8 +28,8 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose); common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { static std::unique_ptr custom_op_domain = std::make_unique(); static std::vector> created_custom_op_list; - static OrtMutex mutex; - std::lock_guard lock(mutex); + static std::mutex mutex; + std::lock_guard lock(mutex); if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { domain_list.push_back(custom_op_domain.get()); return Status::OK(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h index 95abcd1bad2b8..5a7b135fd92cd 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h @@ -520,7 +520,7 @@ void RemoveCachesByType(const std::string& root, std::string file_extension) { * compiled kernels, so the name must be unique and deterministic across models and sessions. 
* */ -HashValue TRTGenerateId(const GraphViewer& graph_viewer) { +HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version, std::string cuda_version) { HashValue model_hash = 0; // find the top level graph @@ -583,12 +583,11 @@ HashValue TRTGenerateId(const GraphViewer& graph_viewer) { #endif #ifdef CUDA_VERSION - hash_str(std::to_string(CUDA_VERSION)); + hash_str(cuda_version); #endif #if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR) - std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR); - hash_str(TRT_VERSION); + hash_str(trt_version); #endif model_hash = hash[0] | (uint64_t(hash[1]) << 32); diff --git a/onnxruntime/core/providers/tvm/custom_logging.cc b/onnxruntime/core/providers/tvm/custom_logging.cc deleted file mode 100644 index 1cabe81f8e87e..0000000000000 --- a/onnxruntime/core/providers/tvm/custom_logging.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -// -// Enable custom logging - this will cause TVM to use a custom implementation -// of tvm::runtime::detail::LogMessage. We use this to change the absolute -// file path to relative file path. - -#include -#include -#include -#include -#include -#include - -// TODO(agladyshev): Make conditional choice of sep for Windows and UNIX -std::string GetFileName(const std::string& file_path, char sep = '/') { - return {std::next(file_path.begin(), file_path.find_last_of(sep) + 1), - file_path.end()}; -} - -std::string GetTimedLogMessage(const std::string& file, int lineno, const std::string& message) { - std::stringstream sstream; - std::string file_name = GetFileName(file); - std::time_t t = std::time(nullptr); - sstream << "[" -#ifdef _WIN32 -// TODO(vvchernov): use #include instead of and localtime_s() approach for WIN32 -#pragma warning(disable : 4996) // _CRT_SECURE_NO_WARNINGS -#endif - << std::put_time(std::localtime(&t), "%H:%M:%S") -#ifdef _WIN32 -#pragma warning(default : 4996) -#endif - << "][TVM] " - << file_name << ":" << lineno << ": " + message; - return sstream.str(); -} - -namespace tvm { -namespace runtime { -namespace detail { -void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { - throw std::runtime_error(GetTimedLogMessage(file, lineno, message)); -} - -void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { - std::cerr << GetTimedLogMessage(file, lineno, message) << std::endl; -} - -} // namespace detail -} // namespace runtime -} // namespace tvm diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher.cc b/onnxruntime/core/providers/tvm/hash_alg/hasher.cc deleted file mode 100644 index bb62b41c7aa85..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/common/common.h" - -#include "hasher.h" // NOLINT(build/include_subdir) -#include "hasher_impl.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -namespace tvm { - -Hasher::Hasher(const std::string& hash_type) { - hasher_ = getHasherImpl(hash_type); -} - -std::string Hasher::hash(const char* src, size_t size) const { - return hasher_->hash(src, size); -} - -std::shared_ptr Hasher::getHasherImpl(const std::string& hash_type) { - if (hash_type == "sha256") { - return std::make_shared(); - } else { - ORT_NOT_IMPLEMENTED("Hasher was not implemented for hash type: ", hash_type); - } - return nullptr; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher.h b/onnxruntime/core/providers/tvm/hash_alg/hasher.h deleted file mode 100644 index 7b2f50def2e36..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_H_ - -#include -#include - -namespace onnxruntime { -namespace tvm { -class HasherImpl; - -class Hasher { - public: - Hasher() = delete; - explicit Hasher(const std::string& hash_type); - virtual ~Hasher() = default; - - std::string hash(const char* src, size_t size) const; - - private: - std::shared_ptr getHasherImpl(const std::string& hash_type); - - private: - std::shared_ptr hasher_; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_H_ diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.cc b/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.cc deleted file mode 100644 index 20aef66f3046a..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "hasher_impl.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -namespace tvm { - -std::string HasherSHA256Impl::hash(const char* src, size_t size) const { - return hexdigest(src, size); -} - -void HasherSHA256Impl::digest(const Ipp8u* src, int size, Ipp8u* dst) { - IppStatus status = ippStsNoErr; - const IppsHashMethod* hashMethod = ippsHashMethod_SHA256(); - status = ippsHashMessage_rmf(src, size, dst, hashMethod); - if (ippStsNoErr != status) { - ORT_THROW("Can't get SHA-256..."); - } -} - -std::string HasherSHA256Impl::digest(const char* src, size_t size) { - const int digest_size_byte = IPP_SHA256_DIGEST_BITSIZE / 8; - auto dst = std::unique_ptr(new char[digest_size_byte]); - digest(reinterpret_cast(src), static_cast(size), reinterpret_cast(dst.get())); - return std::string(dst.get(), digest_size_byte); -} - -std::string HasherSHA256Impl::hexdigest(const char* src, size_t size) { - std::string byte_digest = digest(src, size); - std::stringstream ss; - for (char c : byte_digest) { - ss << std::hex << std::setw(2) << std::setfill('0') << (0xff & c); - } - return ss.str(); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.h b/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.h deleted file mode 100644 index 6c285dd0c78f3..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// Licensed under the MIT License. - -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_IMPL_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_IMPL_H_ - -#include -#include -#include -#include -#include - -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm { - -class HasherImpl { - public: - HasherImpl() = default; - virtual ~HasherImpl() = default; - - virtual std::string hash(const char* src, size_t size) const = 0; -}; - -class HasherSHA256Impl : public HasherImpl { - public: - HasherSHA256Impl() = default; - virtual ~HasherSHA256Impl() = default; - - std::string hash(const char* src, size_t size) const final; - - private: - static void digest(const Ipp8u* src, int size, Ipp8u* dst); - static std::string digest(const char* src, size_t size); - static std::string hexdigest(const char* src, size_t size); -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_IMPL_H_ diff --git a/onnxruntime/core/providers/tvm/symbols.txt b/onnxruntime/core/providers/tvm/symbols.txt deleted file mode 100644 index 8d903acd9ea76..0000000000000 --- a/onnxruntime/core/providers/tvm/symbols.txt +++ /dev/null @@ -1 +0,0 @@ -OrtSessionOptionsAppendExecutionProvider_Tvm diff --git a/onnxruntime/core/providers/tvm/tvm_allocator.cc b/onnxruntime/core/providers/tvm/tvm_allocator.cc deleted file mode 100644 index 4b68f6432e8cc..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_allocator.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include "tvm_allocator.h" -#include "core/framework/session_state.h" -#include "xpu_data_transfer.h" - -namespace onnxruntime { -namespace tvm { - -void* TVMAllocator::Alloc(size_t size) { - void* p = nullptr; - if (size > 0) { - DLDataType dl_type{kDLInt, 8, 1}; - int err = TVMDeviceAllocDataSpace(ctx, size, ::tvm::runtime::kAllocAlignment, dl_type, reinterpret_cast(&p)); - CHECK_EQ(err, 0); - return p; - } - return p; -} - -void TVMAllocator::Free(void* p) { - TVMDeviceFreeDataSpace(ctx, p); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_allocator.h b/onnxruntime/core/providers/tvm/tvm_allocator.h deleted file mode 100644 index f3ba544b8ac46..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_allocator.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifndef TVM_ALLOCATOR -#define TVM_ALLOCATOR - -#include "core/framework/allocator.h" -#include "tvm_common.h" - -namespace onnxruntime { -namespace tvm { - -#define TVM_ALLOC_ALIGN 128 - -class TVMAllocator : public IAllocator { - public: - TVMAllocator() : TVMAllocator(OrtMemoryInfo("TVM", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), - 0, - OrtMemTypeDefault)) {} - explicit TVMAllocator(const OrtMemoryInfo& info) - : IAllocator(info) { - switch (info.device.Type()) { - case OrtDevice::CPU: - ctx = {kDLCPU, info.device.Id()}; - break; - case OrtDevice::GPU: - ctx = {kDLVulkan, info.device.Id()}; - break; - default: - ORT_NOT_IMPLEMENTED("Unsupported device"); - break; - } - } - - virtual void* Alloc(size_t size) override; - virtual void Free(void* p) override; - DLDevice ctx; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_ALLOCATOR diff --git a/onnxruntime/core/providers/tvm/tvm_api.cc b/onnxruntime/core/providers/tvm/tvm_api.cc deleted file mode 100644 index e9a7d002e77c8..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_api.cc +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef _WIN32 -#include -#else -#include // glob(), globfree() -#endif -#include // memset() -#include -#include -#include - -#include -#include -#include - -#include "core/common/common.h" -#include - -#include "tvm_api.h" - -namespace onnxruntime { -namespace tvm { - -using TvmIntArray = ::tvm::Array<::tvm::Integer>; -using TvmPackedFunc = ::tvm::PackedFunc; -namespace tvm_rt = ::tvm::runtime; -namespace tvm_rt_vm = tvm_rt::vm; - -TvmModule TVMCompile(const TvmEPOptions& options, - const std::string& onnx_txt, - const std::string& model_path, - int opset, - const TVMTensorShapes& input_shapes) { - ::tvm::Array shapes; - for (size_t i = 0; i < input_shapes.size(); ++i) { - TvmIntArray shape; - for (auto& dim : input_shapes[i]) { - shape.push_back(::tvm::Integer(dim)); - } - shapes.push_back(shape); - } - - const TvmPackedFunc* compile = tvm_rt::Registry::Get("tvm_onnx_import_and_compile"); - ORT_ENFORCE(compile != nullptr, "Unable to retrieve 'tvm_onnx_import_and_compile'."); - TvmModule mod = (*compile)(TVMByteArray{onnx_txt.data(), onnx_txt.size()}, - model_path, - options.executor, - options.target, - options.target_host, - options.opt_level, - opset, - options.freeze_weights, - shapes, - options.to_nhwc, - options.tuning_file_path, - options.tuning_type); - ORT_ENFORCE(mod.get() != nullptr, "Compiled TVM Module is nullptr!"); - return mod; -} - -std::vector glob(const std::string& dir, const std::string& extension) { - std::vector filenames; -#ifdef _WIN32 - std::string pattern = dir + "/*." + extension; - WIN32_FIND_DATA fd; - HANDLE hFind = ::FindFirstFile(pattern.c_str(), &fd); - if (hFind != INVALID_HANDLE_VALUE) { - do { - if (!(fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { - filenames.push_back( - dir + - ToUTF8String(PathString{k_preferred_path_separator}) + - fd.cFileName); - } - } while (::FindNextFile(hFind, &fd)); - ::FindClose(hFind); - } -#else - glob_t glob_result; - memset(&glob_result, 0, sizeof(glob_result)); - - std::string pattern = dir + "/*." 
+ extension; - int return_value = glob(pattern.c_str(), GLOB_TILDE, NULL, &glob_result); - ORT_ENFORCE(return_value == 0, "No results of glob for pattern: " + pattern); - - for (size_t i = 0; i < glob_result.gl_pathc; ++i) { - filenames.push_back(std::string(glob_result.gl_pathv[i])); - } - globfree(&glob_result); -#endif - return filenames; -} - -std::string filter_lib_paths(const std::vector& lib_paths, const std::string& lib_ext) { - std::string lib_path; - size_t counter = 0; - for (const auto& path : lib_paths) { - if (path.find("libtvm_runtime." + lib_ext) != std::string::npos || - path.find("liboctomized_model." + lib_ext) != std::string::npos) { - ++counter; - } else { - lib_path = path; - } - } - ORT_ENFORCE((lib_paths.size() - counter) == 1, "It should be only one shared library for model after filtering"); - - return lib_path; -} - -static std::unordered_map str2dev_type = { - {"llvm", 1}, - {"stackvm", 1}, - {"cpu", 1}, - {"c", 1}, - {"hybrid", 1}, - {"composite", 1}, - {"cuda", 2}, - {"nvptx", 2}, - {"cl", 4}, - {"opencl", 4}, - {"sdaccel", 4}, - {"aocl", 5}, - {"aocl_sw_emu", 5}, - {"vulkan", 7}, - {"metal", 8}, - {"vpi", 9}, - {"rocm", 10}, - {"ext_dev", 12}, - {"hexagon", 14}, - {"webgpu", 15}}; - -TvmModule TVMSoCompile(const TvmEPOptions& options) { - const std::string& dir = options.so_folder; -#ifdef _WIN32 - std::string lib_ext = "dll"; -#else - std::string lib_ext = "so"; -#endif - const std::string lib_path = filter_lib_paths(glob(dir, lib_ext), lib_ext); - const std::string consts_path = dir + - ToUTF8String(PathString{k_preferred_path_separator}) + - "consts"; - const auto& ro_paths = glob(dir, "ro"); - ORT_ENFORCE(ro_paths.size() == 1, "It should be only one ro file in folder: " + dir); - const std::string vm_exec_code_path = ro_paths[0]; - - TvmModule lib = TvmModule::LoadFromFile(lib_path); - - std::ifstream code(vm_exec_code_path, std::ios::binary); - std::stringstream ss; - ss << code.rdbuf(); - - auto exec_mod = tvm_rt_vm::Executable::Load(ss.str(), lib); - const tvm_rt_vm::Executable* tmp = exec_mod.as(); - auto exec = tvm_rt::GetObjectPtr(const_cast(tmp)); - exec->LoadLateBoundConstantsFromFile(consts_path); - - auto vm = tvm_rt::make_object(); - vm->LoadExecutable(exec); - - size_t pos = options.target.find(" "); - const std::string dev_type_str = options.target.substr(0, pos); - ORT_ENFORCE(!dev_type_str.empty(), "Device was not found in target string"); - uint64_t dev_type = str2dev_type[dev_type_str]; - const uint64_t cpu_type = str2dev_type["cpu"]; - // Initialize the VM for the specified device. If the device is not a CPU, - // We'll need to add a CPU context to drive it. - int arity; - if (dev_type == cpu_type) { - arity = 3; - } else { - arity = 6; - } - uint64_t alloc_type = uint64_t(tvm_rt_vm::AllocatorType::kPooled); - // TODO(vchernov): multiple devices using and using device with specified id are not supported - // Always use the first device of the specified type. - uint64_t device_id = 0; - std::vector init_vals(arity); - std::vector codes(arity); - tvm_rt::TVMArgsSetter setter(init_vals.data(), codes.data()); - setter(0, dev_type); - setter(1, device_id); - setter(2, alloc_type); - // Also initialize a CPU device context. - if (dev_type != cpu_type) { - setter(3, cpu_type); - setter(4, device_id); - setter(5, alloc_type); - } - tvm_rt::TVMRetValue rv; - // Call the packed func with the init arguments. 
- vm->GetFunction("init", nullptr).CallPacked(tvm_rt::TVMArgs(init_vals.data(), codes.data(), arity), &rv); - - return TvmModule(vm); -} - -void TVMSetInputs(TvmModule& mod, - std::vector& inds, - std::vector& inputs) { - TvmPackedFunc set_input = mod.GetFunction("set_input", false); - TvmPackedFunc set_input_zero_copy = mod.GetFunction("set_input_zero_copy", false); - for (size_t i = 0; i < inds.size(); ++i) { - if (reinterpret_cast(inputs[i].data) % tvm_rt::kAllocAlignment == 0) { - set_input_zero_copy(inds[i], &inputs[i]); - } else { - set_input(inds[i], &inputs[i]); - } - } -} - -void TVM_VM_SetInputs(TvmModule& mod, - std::vector& inds, - std::vector& inputs) { - size_t num_total_args = inputs.size() + 1; - std::vector tvm_values(num_total_args); - std::vector tvm_type_codes(num_total_args); - ::tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - const std::string func_name = "main"; - setter(0, func_name.c_str()); - for (size_t k = 0; k < num_total_args - 1; ++k) { - setter(inds[k] + 1, &inputs[k]); - } - - TvmPackedFunc set_input = mod.GetFunction("set_input", false); - ::tvm::runtime::TVMRetValue rv; - set_input.CallPacked(::tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), gsl::narrow_cast(num_total_args)), &rv); -} - -void TVMSetOutputsZeroCopy(TvmModule& mod, - std::vector& outputs) { - TvmPackedFunc set_output = mod.GetFunction("set_output_zero_copy", false); - for (size_t i = 0; i < outputs.size(); ++i) { - set_output(i, &outputs[i]); - } -} - -void TVM_VM_SetOutputsZeroCopy(TvmModule& mod, - std::vector& outputs) { - size_t num_total_args = outputs.size() + 1; - std::vector tvm_values(num_total_args); - std::vector tvm_type_codes(num_total_args); - tvm_rt::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - const std::string func_name = "main"; - setter(0, func_name.c_str()); - for (size_t k = 0; k < num_total_args - 1; ++k) { - setter(k + 1, &outputs[k]); - } - - TvmPackedFunc set_output = mod.GetFunction("set_outputs", false); - tvm_rt::TVMRetValue rv; - set_output.CallPacked(tvm_rt::TVMArgs(tvm_values.data(), tvm_type_codes.data(), gsl::narrow_cast(num_total_args)), &rv); -} - -void TVMGetOutputs(TvmModule& mod, - std::vector& outputs) { - TvmPackedFunc get_output = mod.GetFunction("get_output", false); - for (size_t i = 0; i < outputs.size(); ++i) { - get_output(i, &outputs[i]); - } -} - -void TVM_VM_GetOutputs(TvmModule& mod, - std::vector& outputs) { - TvmPackedFunc get_output = mod.GetFunction("get_output", false); - for (size_t i = 0; i < outputs.size(); ++i) { - // TODO(vvchernov): think about improvement of memory management - tvm_rt::NDArray output_array = get_output(i); - output_array.CopyTo(&outputs[i]); - } -} - -void TVMGetOutputShapes(TvmModule& mod, - TVMTensorShapes& output_shapes) { - size_t size = output_shapes.size(); - TvmPackedFunc get_output = mod.GetFunction("get_output", false); - for (size_t i = 0; i < size; ++i) { - tvm_rt::NDArray output_array = get_output(i); - tvm_rt::ShapeTuple shape_tuple = output_array.Shape(); - size_t dims_num = shape_tuple.size(); - TensorShapeVector dims; - for (size_t j = 0; j < dims_num; ++j) { - dims.push_back(int64_t(shape_tuple[j])); - } - output_shapes[i] = dims; - } -} - -void TVMRun(TvmModule& mod) { - TvmPackedFunc run = mod.GetFunction("run", false); - ORT_ENFORCE(run != nullptr, "Unable to retrieve graph executor run."); - run(); -} - -void TVM_VM_Run(TvmModule& mod) { - TvmPackedFunc run = mod.GetFunction("invoke", false); - ORT_ENFORCE(run != nullptr, 
"Unable to retrieve virtual machine invoke."); - run("main"); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_api.h b/onnxruntime/core/providers/tvm/tvm_api.h deleted file mode 100644 index bbf05f4fc06d9..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_api.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_API_H -#define TVM_API_H - -#include -#include - -#include "tvm_common.h" -#include "tvm_defaults.h" -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -TvmModule TVMCompile(const TvmEPOptions& options, - const std::string& onnx_txt, - const std::string& model_path, - int opset, - const TVMTensorShapes& input_shapes); -TvmModule TVMSoCompile(const TvmEPOptions& options); - -void TVMSetInputs(TvmModule& mod, std::vector& inds, std::vector& inputs); -void TVM_VM_SetInputs(TvmModule& mod, std::vector& inds, std::vector& inputs); -void TVMSetOutputsZeroCopy(TvmModule& mod, std::vector& outputs); -void TVM_VM_SetOutputsZeroCopy(TvmModule& mod, std::vector& outputs); -void TVMGetOutputs(TvmModule& mod, std::vector& outputs); -void TVM_VM_GetOutputs(TvmModule& mod, std::vector& outputs); -void TVMGetOutputShapes(TvmModule& mod, - TVMTensorShapes& output_shapes); -void TVMRun(TvmModule& mod); -void TVM_VM_Run(TvmModule& mod); - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_API_H diff --git a/onnxruntime/core/providers/tvm/tvm_common.h b/onnxruntime/core/providers/tvm/tvm_common.h deleted file mode 100644 index 68e3b6496328a..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_common.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_COMMON_H -#define TVM_COMMON_H - -#include -#include - -#include -#include -#include - -namespace onnxruntime { -namespace tvm { - -using TvmModule = ::tvm::runtime::Module; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_COMMON_H diff --git a/onnxruntime/core/providers/tvm/tvm_compiler.cc b/onnxruntime/core/providers/tvm/tvm_compiler.cc deleted file mode 100644 index 8f4e7e7de9a36..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_compiler.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#include "tvm_compiler.h" -#include "tvm_api.h" - -namespace onnxruntime { -namespace tvm { - -auto TVMCompilerBase::operator()(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) -> ModulePtr { - if (mod_) { - return mod_; - } - - mod_ = std::make_shared(); - this->compileTVMModule(options, input_shapes); - - return mod_; -} - -TVMCompiler::TVMCompiler(std::string&& onnx_model_str, - const std::string& model_path, - int opset) : onnx_model_str_(std::move(onnx_model_str)), - model_path_(model_path), - opset_(opset) { -} - -void TVMCompiler::compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) { - *mod_ = tvm::TVMCompile(options, - onnx_model_str_, - model_path_, - opset_, - input_shapes); - - onnx_model_str_.clear(); -} - -void TVMSoCompiler::compileTVMModule(const TvmEPOptions& options, - [[maybe_unused]] const TVMTensorShapes& input_shapes) { - *mod_ = tvm::TVMSoCompile(options); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_compiler.h b/onnxruntime/core/providers/tvm/tvm_compiler.h deleted file mode 100644 index bfc73d67aa07f..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_compiler.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_COMPILER_H -#define TVM_COMPILER_H - -#include -#include - -#include "tvm_common.h" -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -class TVMCompilerBase { - public: - using ModulePtr = std::shared_ptr; - - TVMCompilerBase() = default; - virtual ~TVMCompilerBase() = default; - - ModulePtr operator()(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes); - - virtual void compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) = 0; - - protected: - ModulePtr mod_; -}; - -class TVMCompiler : public TVMCompilerBase { - public: - TVMCompiler() = delete; - ~TVMCompiler() = default; - - TVMCompiler(std::string&& onnx_model_str, - const std::string& model_path, - int opset); - - void compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) final; - - private: - std::string onnx_model_str_; - std::string model_path_; - int opset_; -}; - -class TVMSoCompiler : public TVMCompilerBase { - public: - TVMSoCompiler() = default; - ~TVMSoCompiler() = default; - - void compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) final; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_COMPILER_H diff --git a/onnxruntime/core/providers/tvm/tvm_defaults.h b/onnxruntime/core/providers/tvm/tvm_defaults.h deleted file mode 100644 index 197d1f363c50d..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_defaults.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_DEFAULTS_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_DEFAULTS_H_ - -#include - -namespace onnxruntime { -namespace tvm { - -namespace env_vars { -static const std::string kDumpSubgraphs = "ORT_TVM_DUMP_SUBGRAPHS"; -} // namespace env_vars - -constexpr const char* default_executor_type = "vm"; -constexpr const char* vm_executor_type = "vm"; -constexpr const char* graph_executor_type = "graph"; - -constexpr const char* default_target_str = "llvm"; -constexpr const char* llvm_target_str = "llvm"; - -constexpr const char* cpu_target_str = "cpu"; -constexpr const char* gpu_target_str = "gpu"; - -constexpr const char* default_tuning_type = "AutoTVM"; -constexpr const char* autotvm_tuning_type = "AutoTVM"; -constexpr const char* ansor_tuning_type = "Ansor"; - -constexpr const unsigned int default_opt_level = 3; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_DEFAULTS_H_ diff --git a/onnxruntime/core/providers/tvm/tvm_ep_options.cc b/onnxruntime/core/providers/tvm/tvm_ep_options.cc deleted file mode 100644 index 70e99833cd78b..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_ep_options.cc +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include - -#include "core/common/common.h" -#include "core/common/cpuid_info.h" -#include "core/framework/provider_options_utils.h" - -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -namespace provider_option_names { -constexpr const char* kExecutor = "executor"; -constexpr const char* kSoFolder = "so_folder"; -constexpr const char* kCheckHash = "check_hash"; -constexpr const char* kHashFilePath = "hash_file_path"; -constexpr const char* kTarget = "target"; -constexpr const char* kTargetHost = "target_host"; -constexpr const char* kOptLevel = "opt_level"; -constexpr const char* kFreezeWeights = "freeze_weights"; -constexpr const char* kSetOutputZeroCopy = "set_output_zero_copy"; -constexpr const char* kToNHWC = "to_nhwc"; -constexpr const char* kTuningFilePath = "tuning_file_path"; -constexpr const char* kTuningType = "tuning_type"; -constexpr const char* kInputNames = "input_names"; -constexpr const char* kInputShapes = "input_shapes"; - -static const std::unordered_set valid_keys{ - std::string{kExecutor}, - std::string{kSoFolder}, - std::string{kCheckHash}, - std::string{kHashFilePath}, - std::string{kTarget}, - std::string{kTargetHost}, - std::string{kOptLevel}, - std::string{kFreezeWeights}, - std::string{kSetOutputZeroCopy}, - std::string{kToNHWC}, - std::string{kTuningFilePath}, - std::string{kTuningType}, - std::string{kInputNames}, - std::string{kInputShapes}}; - -} // namespace provider_option_names - -size_t split(const std::string& src, std::vector& dst, char ch) { - dst.clear(); - - size_t pos = src.find(ch); - size_t initialPos = 0; - while (pos != std::string::npos) { - dst.push_back(src.substr(initialPos, pos - initialPos)); - initialPos = pos + 1; - - pos = src.find(ch, initialPos); - } - dst.push_back(src.substr(initialPos, std::min(pos, src.size()) - initialPos + 1)); - - return dst.size(); -} - -TvmEPOptions TvmEPOptionsHelper::FromOptionsString(const char* opt_str) { - std::string settings{opt_str}; - ProviderOptions options; - if (!settings.empty()) { - const std::string& str = settings; - - // tokenize settings - std::regex reg("\\s*,\\s*"); - std::sregex_token_iterator iter(str.begin(), str.end(), reg, -1); - 
std::sregex_token_iterator iter_end; - std::vector pairs(iter, iter_end); - - ORT_ENFORCE(pairs.size() > 0); - - for (const auto& pair : pairs) { - auto pos_colon = pair.find(':'); - ORT_ENFORCE(pos_colon != std::string::npos, "Invalid key value pair."); - std::string key = pair.substr(0, pos_colon); - std::string value = pair.substr(pos_colon + 1); - - // trim leading and trailing spaces from key/value - key = whitespace_trimming(key); - value = whitespace_trimming(value); - - // Check keys of obtained options - if (tvm::provider_option_names::valid_keys.count(key) == 0) { - ORT_NOT_IMPLEMENTED("TvmOptions: unknown option (", key, ")"); - } - - options[key] = value; - } - } - - return TvmEPOptionsHelper::FromProviderOptions(options); -} - -std::string TvmEPOptionsHelper::whitespace_trimming(const std::string& str) { - const std::string WHITESPACE = " \n\r\t\f\v"; - size_t start = str.find_first_not_of(WHITESPACE); - if (start == std::string::npos) { - return ""; - } else { - size_t end = str.find_last_not_of(WHITESPACE); - ORT_ENFORCE(end != std::string::npos); - return str.substr(start, end + 1); - } -} - -TvmEPOptions TvmEPOptionsHelper::FromProviderOptions(const ProviderOptions& pr_options) { - TvmEPOptions options{}; - - ORT_THROW_IF_ERROR( - ProviderOptionsParser{} - .AddAssignmentToReference(tvm::provider_option_names::kExecutor, options.executor) - .AddAssignmentToReference(tvm::provider_option_names::kSoFolder, options.so_folder) - .AddAssignmentToReference(tvm::provider_option_names::kCheckHash, options.check_hash) - .AddAssignmentToReference(tvm::provider_option_names::kHashFilePath, options.hash_file_path) - .AddAssignmentToReference(tvm::provider_option_names::kTarget, options.target) - .AddAssignmentToReference(tvm::provider_option_names::kTargetHost, options.target_host) - .AddAssignmentToReference(tvm::provider_option_names::kOptLevel, options.opt_level) - .AddAssignmentToReference(tvm::provider_option_names::kFreezeWeights, options.freeze_weights) - .AddAssignmentToReference(tvm::provider_option_names::kSetOutputZeroCopy, options.set_output_zero_copy) - .AddAssignmentToReference(tvm::provider_option_names::kToNHWC, options.to_nhwc) - .AddAssignmentToReference(tvm::provider_option_names::kTuningFilePath, options.tuning_file_path) - .AddAssignmentToReference(tvm::provider_option_names::kTuningType, options.tuning_type) - .AddAssignmentToReference(tvm::provider_option_names::kInputNames, options.input_names_str) - .AddAssignmentToReference(tvm::provider_option_names::kInputShapes, options.input_shapes_str) - .Parse(pr_options)); - - optionsPostprocess(options); - - return options; -} - -void TvmEPOptionsHelper::optionsPostprocess(TvmEPOptions& options) { - setInputShapes(options); - targetPostprocess(options.target); - targetHostPostprocess(options.target, options.target_host); - optLevelPostprocess(options.opt_level); -} - -bool TvmEPOptionsHelper::checkCPUTarget(const std::string& target) { - bool check = target.find("llvm") != std::string::npos; - return check; -} - -bool TvmEPOptionsHelper::checkGPUTarget(const std::string& target) { - bool check = (target.find("cuda") != std::string::npos || - target.find("opencl") != std::string::npos || - target.find("metal") != std::string::npos || - target.find("vulkan") != std::string::npos); - return check; -} - -void TvmEPOptionsHelper::setInputShapes(TvmEPOptions& options) { - if (options.input_names_str.empty() && options.input_shapes_str.empty()) - return; - ORT_ENFORCE(!options.input_names_str.empty() && 
!options.input_shapes_str.empty(), - "Both provider options \"input_names\" and \"input_shapes\" should be empty or full"); - - std::vector name_set; - std::string trimmed_names = whitespace_trimming(options.input_names_str); - size_t inp_tensors_num = split(trimmed_names, name_set, ' '); - ORT_ENFORCE(inp_tensors_num, "There is no any input tensor names!"); - - std::string trimmed_shapes = whitespace_trimming(options.input_shapes_str); - size_t end_pos = trimmed_shapes.find_last_of(']'); - ORT_ENFORCE(end_pos != std::string::npos, "Invalid string for input shapes. Symbol ] is not found"); - ORT_ENFORCE(end_pos == (trimmed_shapes.size() - 1), - "Invalid string for input shapes. Symbol ] should be last after whitespace trimming"); - - std::vector shape_set; - split(trimmed_shapes, shape_set, ']'); - shape_set.pop_back(); - ORT_ENFORCE(shape_set.size() == inp_tensors_num, - "Number of shapes is not the same as number of input tensor names"); - - for (size_t i = 0; i < inp_tensors_num; ++i) { - size_t pos = shape_set[i].find('['); - ORT_ENFORCE(pos != std::string::npos, "There is no symbol [ as pair for ]"); - std::string numbers = shape_set[i].substr(pos + 1); - std::vector number_set; - ORT_ENFORCE(split(numbers, number_set, ' '), "There is no any number between [ and ] symbols"); - - TensorShapeVector dims; - for (const auto& number : number_set) { - dims.push_back(std::stoi(number)); - } - - options.input_shapes[name_set[i]] = dims; - } -} - -void TvmEPOptionsHelper::targetPostprocess(std::string& target) { - if (target == tvm::cpu_target_str || - target == tvm::llvm_target_str) { - ProcessCPUTarget(target); - } else if (target == tvm::gpu_target_str) { - ProcessGPUTarget(); - } else if (target.empty()) { - ORT_NOT_IMPLEMENTED("target option is empty!"); - } else { - // TODO(vvchernov): extend mechanism of auto-definition of target - // target is gotten from option set up by client - } -} - -void TvmEPOptionsHelper::ProcessCPUTarget(std::string& target) { - const auto& cpu_id_info = CPUIDInfo::GetCPUIDInfo(); - // auto detect from CPU ID - if (cpu_id_info.HasAVX512Skylake()) { - target = tvm::cpu_targets::LLVM_TARGET_SKYLAKE_AVX512; - } else if (cpu_id_info.HasAVX512f()) { - target = tvm::cpu_targets::LLVM_TARGET_AVX512; - } else if (cpu_id_info.HasAVX2()) { - target = tvm::cpu_targets::LLVM_TARGET_AVX2; - } else if (cpu_id_info.HasAVX()) { - target = tvm::cpu_targets::LLVM_TARGET_AVX; - } else { - // TODO(vvchernov): extend mechanism of auto-definition of cpu target - target = tvm::llvm_target_str; - } -} - -void TvmEPOptionsHelper::ProcessGPUTarget() { - ORT_NOT_IMPLEMENTED("GPU target auto-defenition is not implemented now!"); -} - -void TvmEPOptionsHelper::targetHostPostprocess(const std::string& target, std::string& target_host) { - if ((target_host == tvm::cpu_target_str || - target_host == tvm::llvm_target_str) && - target_host != target) { - target_host = target; - } else if (target_host.empty()) { - target_host = target; - } else { - // TODO(vvchernov): extend mechanism of auto-definition of target host - // target host is gotten from option set up by client - } -} - -void TvmEPOptionsHelper::optLevelPostprocess(unsigned int& opt_level) { - if (opt_level < 1) { - opt_level = tvm::default_opt_level; - } -} - -std::ostream& operator<<(std::ostream& out, const TvmEPOptions& options) { - out << "TVM EP options:\n" - << "executor type: " << options.executor << "\n" - << "so_folder: " << options.so_folder << "\n" - << "check_hash: " << options.check_hash << "\n" - << "hash_file_path: " 
<< options.hash_file_path << "\n" - << "target: " << options.target << "\n" - << "target_host: " << options.target_host << "\n" - << "opt level: " << options.opt_level << "\n" - << "freeze weights: " << options.freeze_weights << "\n" - << "set_output_zero_copy: " << options.set_output_zero_copy << "\n" - << "tuning file path: " << options.tuning_file_path << "\n" - << "tuning type: " << options.tuning_type << "\n" - << "convert layout to NHWC: " << options.to_nhwc << "\n" - << "input tensor names: " << options.input_names_str << "\n" - << "input tensor shapes: " << options.input_shapes_str; - return out; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_ep_options.h b/onnxruntime/core/providers/tvm/tvm_ep_options.h deleted file mode 100644 index 0f2db30a3b304..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_ep_options.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_EXECUTION_PROVIDER_OPTIONS_H -#define TVM_EXECUTION_PROVIDER_OPTIONS_H - -#include -#include -#include -#include - -#include "core/framework/provider_options.h" -#include "core/framework/tensor_shape.h" - -#include "tvm_defaults.h" - -namespace onnxruntime { - -namespace tvm { -namespace cpu_targets { -// TODO(vvchernov): avx and avx512 need more careful differentiation for target -const std::string LLVM_TARGET_AVX = "llvm -mcpu=corei7-avx"; -const std::string LLVM_TARGET_AVX2 = "llvm -mcpu=core-avx2"; -const std::string LLVM_TARGET_SKYLAKE_AVX512 = "llvm -mcpu=skylake-avx512"; -const std::string LLVM_TARGET_AVX512 = "llvm -mcpu=skylake-avx512"; -} // namespace cpu_targets - -using TVMTensorShapes = std::vector; -using TVMInputShapes = std::unordered_map; -using InputsInfoMap = std::unordered_map; - -// Information needed to construct an TVM execution provider. 
-struct TvmEPOptions { - std::string executor{tvm::default_executor_type}; - std::string so_folder{""}; - bool check_hash = false; - std::string hash_file_path{""}; - std::string target{tvm::default_target_str}; - std::string target_host{tvm::default_target_str}; - unsigned int opt_level{tvm::default_opt_level}; - bool freeze_weights = true; - bool to_nhwc = false; - bool set_output_zero_copy = true; - std::string tuning_file_path{""}; - std::string tuning_type{tvm::default_tuning_type}; - std::string input_names_str{""}; - std::string input_shapes_str{""}; - TVMInputShapes input_shapes{}; - TVMTensorShapes output_shapes{}; -}; - -std::ostream& operator<<(std::ostream& out, const TvmEPOptions& options); - -class TvmEPOptionsHelper { - public: - static TvmEPOptions FromOptionsString(const char* options); - static TvmEPOptions FromProviderOptions(const ProviderOptions& options); - static std::string whitespace_trimming(const std::string& str); - - static bool checkCPUTarget(const std::string& target); - static bool checkGPUTarget(const std::string& target); - - private: - static void optionsPostprocess(TvmEPOptions& options); - static void setInputShapes(TvmEPOptions& options); - static void targetPostprocess(std::string& target); - static void ProcessCPUTarget(std::string& target); - static void ProcessGPUTarget(); - static void targetHostPostprocess(const std::string& target, std::string& target_host); - static void optLevelPostprocess(unsigned int& opt_level); -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_EXECUTION_PROVIDER_OPTIONS_H diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc b/onnxruntime/core/providers/tvm/tvm_execution_provider.cc deleted file mode 100644 index 61ee8f899dbf1..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include "core/common/common.h" -#include "core/framework/execution_provider.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/kernel_registry.h" -#include "core/framework/compute_capability.h" -#include "core/graph/graph_proto_serializer.h" -#include "core/platform/env.h" -#include "core/graph/model.h" - -#include "tvm_execution_provider.h" -#include "xpu_data_transfer.h" -#include "tvm_allocator.h" -#include "tvm_utils.h" -#include "tvm_api.h" - -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace tvm { - -// Information to construct kernel function state. 
-struct TVMFuncState { - AllocateFunc allocate_func = nullptr; - DestroyFunc release_func = nullptr; - AllocatorHandle allocator = nullptr; - std::shared_ptr compiler = nullptr; -}; - -TvmExecutionProvider::TvmExecutionProvider(const TvmEPOptions& options) - : IExecutionProvider{kTvmExecutionProvider}, - options_{options} { - AllocatorCreationInfo default_memory_info = {[](int) { - return std::make_unique(); - }, - 0, false}; - // Get environment variables - const Env& env_instance = Env::Default(); - - const std::string dump_subgraphs_env = env_instance.GetEnvironmentVar(env_vars::kDumpSubgraphs); - if (!dump_subgraphs_env.empty()) { - dump_subgraphs_ = std::stoi(dump_subgraphs_env) != 0; - } -} - -std::vector TvmExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo default_memory_info = {[](int) { - return std::make_unique(); - }, - 0, false}; - return std::vector{CreateAllocator(default_memory_info)}; // TODO(leca): REVIEW: will CPU EP also use this? -} - -TvmExecutionProvider::~TvmExecutionProvider() {} - -std::vector> -TvmExecutionProvider::GetCapability(const GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { - std::vector> result; - if (graph_viewer.IsSubgraph()) { - return result; - } - - const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); - - std::unordered_set required_initializers; - const std::vector& sorted_nodes = graph_viewer.GetNodesInTopologicalOrder(); - std::unique_ptr sub_graph = std::make_unique(); - for (auto& node_idx : sorted_nodes) { - graph_viewer.GetNode(node_idx)->ForEachDef([&required_initializers, &init_tensors](const NodeArg& node_arg, bool is_input) { - if(is_input && init_tensors.count(node_arg.Name())) { - required_initializers.insert(node_arg.Name()); - } }, true); - } - - auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "TVMStandalone"; - meta_def->domain = "StandaloneTest"; - std::vector inputs; - std::vector outputs; - - for (auto& nodeArgPtr : graph_viewer.GetInputs()) { - inputs.push_back(nodeArgPtr->Name()); - } - - for (auto& name : required_initializers) { - inputs.push_back(name); - } - - for (auto& nodeArgPtr : graph_viewer.GetOutputs()) { - outputs.push_back(nodeArgPtr->Name()); - } - meta_def->inputs = inputs; - meta_def->outputs = outputs; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - sub_graph->SetMetaDef(std::move(meta_def)); - sub_graph->nodes = sorted_nodes; - result.push_back( - std::make_unique(std::move(sub_graph))); - return result; -} - -common::Status TvmExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - printOptions(); - for (auto& fused_node_graph : fused_nodes_and_graphs) { - const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; - const Node& fused_node = fused_node_graph.fused_node; - const std::string func_name = fused_node.Name(); - Model model(graph_body_viewer.Name(), true, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), graph_body_viewer.DomainToVersionMap(), - std::vector(), *GetLogger()); - ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); - // TVM EP is using static lib approach, so invoke serializer directly. 
- GraphViewerToProto(graph_body_viewer, *model_proto.mutable_graph(), true, true); - auto opset = model_proto.add_opset_import(); - opset->set_domain(kOnnxDomain); - opset->set_version(graph_body_viewer.DomainToVersionMap().at(kOnnxDomain)); - - std::string onnx_model_str; - model_proto.SerializeToString(&onnx_model_str); - compilers_[func_name] = std::make_shared(std::move(onnx_model_str), - ToUTF8String(fused_node.ModelPath().ToPathString()), - int(opset->version())); - InputsInfoMap all_input_shapes; - auto mod = compileModel(func_name, graph_body_viewer, all_input_shapes); - - std::vector output_tensors; - prepareOutputTensors(mod, output_tensors, graph_body_viewer.GetOutputs().size()); - - runners_[func_name] = std::make_shared(options_, mod, all_input_shapes, output_tensors); - - if (dump_subgraphs_) { - std::fstream dump("/tmp/" + func_name + ".onnx", - std::ios::out | std::ios::trunc | std::ios::binary); - model_proto.SerializeToOstream(&dump); - } - - // TODO(vvchernov): implement ops checking and mechanism of gracefully passing the responsibility to other EPs - // if the checking fails due to unsupported op(s) - NodeComputeInfo compute_info = prepareComputeInfo(func_name); - - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} - -std::unique_ptr TvmExecutionProvider::GetDataTransfer() const { - // TODO(vvchernov): target or target host? - if (TvmEPOptionsHelper::checkGPUTarget(options_.target)) { - return std::make_unique(); - } else if (TvmEPOptionsHelper::checkCPUTarget(options_.target)) { - return std::make_unique(); - } else { - ORT_NOT_IMPLEMENTED("TVM GetDataTransfer is not implemented for target ", options_.target); - } -} - -void TvmExecutionProvider::printOptions() { - LOGS(*GetLogger(), INFO) << options_; -} - -std::shared_ptr TvmExecutionProvider::compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& all_input_shapes) { - all_input_shapes.clear(); - - TVMTensorShapes input_shapes; - if (options_.freeze_weights) { - setInputShapesForFreezedNN(graph_viewer, input_shapes, all_input_shapes); - } else { - setInputShapesForUnfreezedNN(graph_viewer, input_shapes, all_input_shapes); - } - - std::shared_ptr mod = compilers_[func_name]->operator()(options_, input_shapes); - - return mod; -} - -void TvmExecutionProvider::setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - if (!graph_viewer.IsInitializedTensor(node->Name())) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - input_shapes.emplace_back(shape); - } - } -} - -void TvmExecutionProvider::setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - if (!graph_viewer.IsInitializedTensor(node->Name())) { - input_shapes.emplace_back(shape); - } - } -} - -TensorShapeVector TvmExecutionProvider::getInputShape(const NodeArg* node) { - TensorShapeVector shape; - const auto& node_name = node->Name(); - if (!options_.input_shapes.empty() && - options_.input_shapes.count(node_name)) { - shape = 
options_.input_shapes[node_name]; - } else { - shape = convertTensorShape(*node->Shape()); - } - - return shape; -} - -TensorShapeVector TvmExecutionProvider::convertTensorShape(const TensorShapeProto& shape_proto) { - TensorShape ort_shape = utils::GetTensorShapeFromTensorShapeProto(shape_proto); - size_t dims = ort_shape.NumDimensions(); - - TensorShapeVector shape(dims); - for (size_t j = 0; j < dims; ++j) { - int64_t dim = int64_t(ort_shape[j]); - ORT_ENFORCE(dim > 0, "Input dimension is not positive value (dim = " + std::to_string(dim) + "). " + - "Please use provider options to setup input_names and input_shapes"); - shape[j] = dim; - } - - return shape; -} - -void TvmExecutionProvider::prepareOutputTensors(const std::shared_ptr& mod, - std::vector& output_tensors, - size_t num) { - ORT_ENFORCE(mod != nullptr, "TVM module is not compiled"); - output_tensors.clear(); - options_.output_shapes.clear(); - options_.output_shapes.resize(num); - - if (options_.executor != "vm") { - TVMGetOutputShapes(*mod, options_.output_shapes); - } - - for (auto& output_shape : options_.output_shapes) { - DLTensor t; - // Draft for tensor, correct data is defined during inference - t.strides = nullptr; - t.byte_offset = 0; - t.data = nullptr; - if (options_.executor == "vm") { - t.ndim = 0; - t.shape = nullptr; - } else { - t.ndim = output_shape.size(); - t.shape = output_shape.data(); - } - - output_tensors.push_back(t); - } -} - -NodeComputeInfo TvmExecutionProvider::prepareComputeInfo(const std::string& func_name) { - NodeComputeInfo compute_info; - compute_info.create_state_func = std::bind(&TvmExecutionProvider::createStateFunc, - this, - std::placeholders::_1, - std::placeholders::_2); - - compute_info.release_state_func = [](FunctionState state) { - if (state) - delete static_cast(state); - }; - - compute_info.compute_func = *runners_[func_name].get(); - - return compute_info; -} - -int TvmExecutionProvider::createStateFunc(ComputeContext* context, FunctionState* state) { - auto* state_ptr = new TVMFuncState(); - *state_ptr = {context->allocate_func, - context->release_func, - context->allocator_handle, - compilers_[context->node_name]}; - // TODO(vvchernov): Who and when release state? - *state = state_ptr; - return 0; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.h b/onnxruntime/core/providers/tvm/tvm_execution_provider.h deleted file mode 100644 index e216570c2bebc..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifndef TVM_EXECUTION_PROVIDER_H -#define TVM_EXECUTION_PROVIDER_H - -#include -#include -#include -#include - -#include "core/common/logging/logging.h" -#include "core/framework/execution_provider.h" -#include "core/platform/ort_mutex.h" - -#include "tvm_compiler.h" -#include "tvm_runner.h" - -namespace onnxruntime { -class Graph; -class NodeArg; -namespace tvm { - -class TvmExecutionProvider : public IExecutionProvider { - using Compiler = TVMCompilerBase; - using Compilers = std::unordered_map>; - using Runner = TVMRunner; - using Runners = std::unordered_map>; - - public: - explicit TvmExecutionProvider(const TvmEPOptions& options); - virtual ~TvmExecutionProvider(); - - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; - - common::Status Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; - std::unique_ptr GetDataTransfer() const override; - std::vector CreatePreferredAllocators() override; - - private: - void printOptions(); - std::shared_ptr compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& inputs_info); // NOLINT - void setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - void setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - TensorShapeVector getInputShape(const NodeArg* node); - TensorShapeVector convertTensorShape(const ONNX_NAMESPACE::TensorShapeProto& shape_proto); - void prepareOutputTensors(const std::shared_ptr& mod, - std::vector& output_tensors, size_t num); // NOLINT - NodeComputeInfo prepareComputeInfo(const std::string& func_name); - int createStateFunc(ComputeContext*, FunctionState*); - - private: - TvmEPOptions options_; - Compilers compilers_; - Runners runners_; - bool dump_subgraphs_ = false; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_EXECUTION_PROVIDER_H diff --git a/onnxruntime/core/providers/tvm/tvm_provider_factory.cc b/onnxruntime/core/providers/tvm/tvm_provider_factory.cc deleted file mode 100644 index d83fd8ee4d1cb..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_provider_factory.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include - -#include "core/providers/tvm/tvm_provider_factory.h" -#include "core/session/abi_session_options_impl.h" - -#include "tvm_execution_provider.h" -#include "tvm_provider_factory_creator.h" -#include "tvm_so_execution_provider.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { - -struct TvmProviderFactory : IExecutionProviderFactory { - TvmProviderFactory(const tvm::TvmEPOptions& options) : options_{options} {} - ~TvmProviderFactory() = default; - - std::unique_ptr CreateProvider() override { - std::unique_ptr provider = nullptr; - if (options_.so_folder != "") { - ORT_ENFORCE(options_.executor == "vm", - "Only virtual machine module is compiled from shared lib and dependences!"); - provider = std::move(std::make_unique(options_)); - } else { - provider = std::move(std::make_unique(options_)); - } - - return provider; - } - - private: - tvm::TvmEPOptions options_; -}; - -std::shared_ptr TVMProviderFactoryCreator::Create(const char* opt_str) { - tvm::TvmEPOptions options = tvm::TvmEPOptionsHelper::FromOptionsString(opt_str); - return std::make_shared(options); -} - -std::shared_ptr TVMProviderFactoryCreator::Create(const tvm::TvmEPOptions& options) { - return std::make_shared(options); -} -} // namespace onnxruntime - -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tvm, - _In_ OrtSessionOptions* options, - _In_ const char* opt_str) { - onnxruntime::tvm::TvmEPOptions tvm_options = onnxruntime::tvm::TvmEPOptionsHelper::FromOptionsString(opt_str); - options->provider_factories.push_back(onnxruntime::TVMProviderFactoryCreator::Create(tvm_options)); - return nullptr; -} diff --git a/onnxruntime/core/providers/tvm/tvm_provider_factory_creator.h b/onnxruntime/core/providers/tvm/tvm_provider_factory_creator.h deleted file mode 100644 index 2d7e06b5b7c59..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_provider_factory_creator.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/providers.h" - -namespace onnxruntime { -namespace tvm { -struct TvmEPOptions; -} - -struct TVMProviderFactoryCreator { - static std::shared_ptr Create(const tvm::TvmEPOptions& options); - static std::shared_ptr Create(const char* params); -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_runner.cc b/onnxruntime/core/providers/tvm/tvm_runner.cc deleted file mode 100644 index 5dda8f5bf9c3e..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/graph/model.h" -#include "core/framework/tensorprotoutils.h" - -#include "tvm_runner.h" - -using namespace ONNX_NAMESPACE; -namespace onnxruntime { -namespace tvm { - -TVMRunner::TVMRunner(const TvmEPOptions& options, - const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const std::vector& output_tensors) { - runner_ = getTVMRunnerImpl(mod, options, inputs_info, output_tensors); -} - -common::Status TVMRunner::operator()(FunctionState state, const OrtApi* /*api*/, OrtKernelContext* context) { - Ort::KernelContext ctx(context); - return runner_->run(ctx); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_runner.h b/onnxruntime/core/providers/tvm/tvm_runner.h deleted file mode 100644 index 4b7349ee3405e..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_RUNNER_H -#define TVM_RUNNER_H - -#include -#include - -#include "tvm_runner_impl.h" - -namespace onnxruntime { -namespace tvm { - -class TVMRunner { - public: - TVMRunner() = delete; - virtual ~TVMRunner() = default; - - TVMRunner(const TvmEPOptions& options, - const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const std::vector& output_tensor); - - common::Status operator()(FunctionState state, const OrtApi* api, OrtKernelContext* context); - - private: - std::shared_ptr runner_; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_TVM_RUNNER_H diff --git a/onnxruntime/core/providers/tvm/tvm_runner_impl.cc b/onnxruntime/core/providers/tvm/tvm_runner_impl.cc deleted file mode 100644 index c88de2652f14b..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner_impl.cc +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/framework/tensorprotoutils.h" - -#include "tvm_runner_impl.h" -#include "tvm_utils.h" -#include "tvm_api.h" - -namespace onnxruntime { -namespace tvm { - -/* ------------------------------------ RunnerImplFactory ----------------------------- */ - -std::shared_ptr getTVMRunnerImpl(const std::shared_ptr& mod, - const TvmEPOptions& options, - const InputsInfoMap& inputs_info, - const std::vector output_tensors) { - const std::string& name = options.executor; - if (name == "graph") { - return std::make_shared(mod, inputs_info, options.output_shapes, - output_tensors, options.set_output_zero_copy); - } else if (name == "vm") { - return std::make_shared(mod, inputs_info, options.output_shapes, - output_tensors, options.set_output_zero_copy); - } - return nullptr; -} - -/* ------------------------------------ RunnerImpl ------------------------------------ */ - -RunnerImpl::RunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector output_tensors, - bool set_output_zero_copy) : mod_(mod), - inputs_info_(inputs_info), - output_shapes_(output_shapes), - output_tensors_(output_tensors), - set_output_zero_copy_(set_output_zero_copy) { -} - -void RunnerImpl::convert_input_tensors2dl_tensors(Ort::KernelContext& context, - std::vector& dst, - std::vector& dst_inds) { - size_t num = inputs_info_.size(); - dst.reserve(num); - dst_inds.reserve(num); - for (auto& info : inputs_info_) { - // TODO(vvchernov): decomposition declaration only available with -std=c++1z or -std=gnu++1z - auto& i = info.first; - auto& shape = info.second; - - auto input_tensor = context.GetInput(i); - ORT_ENFORCE(input_tensor.IsTensor()); - - auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType(); - const auto tensor_type = input_tensor.GetTensorTypeAndShapeInfo().GetElementType(); - - DLTensor t; - t.device = GetDLDevice(ort_device_type); - t.dtype = GetDataType(tensor_type); - t.strides = nullptr; - t.byte_offset = 0; - t.data = const_cast(input_tensor.GetTensorRawData()); - t.ndim = shape.size(); - t.shape = shape.data(); - dst.emplace_back(t); - dst_inds.push_back(i); - } -} - -void RunnerImpl::add_device_type_data2output_tensors(Ort::KernelContext& context) { - size_t num_outputs = output_tensors_.size(); - for (auto i = 0u; i < num_outputs; i++) { - // setup output tensor property - auto output_tensor = context.GetOutput(i, - output_shapes_[i].data(), - output_shapes_[i].size()); - ORT_ENFORCE(output_tensor.IsTensor()); - - output_tensors_[i].device = - GetDLDevice(output_tensor.GetTensorMemoryInfo().GetDeviceType()); - output_tensors_[i].dtype = - GetDataType(output_tensor.GetTensorTypeAndShapeInfo().GetElementType()); - output_tensors_[i].data = output_tensor.GetTensorMutableRawData(); - } -} - -/* ------------------------------------ GERunnerImpl ------------------------------------ */ - -GERunnerImpl::GERunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector output_tensors, - bool set_output_zero_copy) : RunnerImpl(mod, inputs_info, output_shapes, output_tensors, set_output_zero_copy) { -} - -void GERunnerImpl::set_input(Ort::KernelContext& context) { - std::vector inds; - std::vector dl_tensors_inputs; - convert_input_tensors2dl_tensors(context, dl_tensors_inputs, inds); - - tvm::TVMSetInputs(*mod_, inds, dl_tensors_inputs); -} - -void GERunnerImpl::connect_output_tensors2ort(Ort::KernelContext& context) { - 
add_device_type_data2output_tensors(context); -} - -void GERunnerImpl::set_output_zero_copy() { - tvm::TVMSetOutputsZeroCopy(*mod_, output_tensors_); -} - -void GERunnerImpl::run() { - tvm::TVMRun(*mod_); -} - -void GERunnerImpl::get_outputs() { - tvm::TVMGetOutputs(*mod_, output_tensors_); -} - -/* ------------------------------------ VMRunnerImpl ------------------------------------ */ - -VMRunnerImpl::VMRunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector output_tensors, - bool set_output_zero_copy) : RunnerImpl(mod, inputs_info, output_shapes, output_tensors, set_output_zero_copy) { -} - -void VMRunnerImpl::set_input(Ort::KernelContext& context) { - std::vector inds; - std::vector dl_tensors_inputs; - convert_input_tensors2dl_tensors(context, dl_tensors_inputs, inds); - - tvm::TVM_VM_SetInputs(*mod_, inds, dl_tensors_inputs); -} - -void VMRunnerImpl::connect_output_tensors2ort(Ort::KernelContext& context) { - // TODO(vvchernov): try to find more flexible solution - if (!probe_infer_) { - infer_once_to_get_output_shapes(); - } - - add_device_type_data2output_tensors(context); -} - -void VMRunnerImpl::set_output_zero_copy() { - tvm::TVM_VM_SetOutputsZeroCopy(*mod_, output_tensors_); -} - -void VMRunnerImpl::run() { - tvm::TVM_VM_Run(*mod_); -} - -void VMRunnerImpl::get_outputs() { - tvm::TVM_VM_GetOutputs(*mod_, output_tensors_); -} - -void VMRunnerImpl::infer_once_to_get_output_shapes() { - run(); - size_t num_outputs = output_tensors_.size(); - // TODO(vvchernov): check it - output_shapes_.resize(num_outputs); - tvm::TVMGetOutputShapes(*mod_, output_shapes_); - for (size_t i = 0; i < num_outputs; ++i) { - output_tensors_[i].ndim = output_shapes_[i].size(); - output_tensors_[i].shape = output_shapes_[i].data(); - } - probe_infer_ = true; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_runner_impl.h b/onnxruntime/core/providers/tvm/tvm_runner_impl.h deleted file mode 100644 index 8c325303673b6..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner_impl.h +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
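VMRunnerImpl above differs from the graph-executor path mainly in connect_output_tensors2ort: because the virtual machine does not expose output shapes up front, it runs one probe inference, caches the shapes, and only then binds ORT's output tensors. A minimal sketch of that probe-once idea, with stand-in names and hard-coded shapes:

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative sketch of the "probe once, then bind" pattern used by the deleted
// VMRunnerImpl::infer_once_to_get_output_shapes; names and shapes are stand-ins.
class ShapeProbingRunner {
 public:
  void ConnectOutputs() {
    if (!probed_) {
      // First call: run once purely to learn the output shapes.
      output_shapes_ = RunAndQueryShapes();
      probed_ = true;
    }
    // From here on, outputs can be bound against the cached shapes.
    for (const auto& shape : output_shapes_) {
      std::cout << "binding output with rank " << shape.size() << "\n";
    }
  }

 private:
  std::vector<std::vector<int64_t>> RunAndQueryShapes() {
    // Stands in for TVM_VM_Run followed by TVMGetOutputShapes.
    return {{1, 3, 224, 224}, {1, 1000}};
  }

  bool probed_ = false;
  std::vector<std::vector<int64_t>> output_shapes_;
};

int main() {
  ShapeProbingRunner runner;
  runner.ConnectOutputs();  // probes on the first call
  runner.ConnectOutputs();  // reuses the cached shapes
  return 0;
}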
- -#ifndef TVM_RUNNER_IMPL_H -#define TVM_RUNNER_IMPL_H - -#include -#include -#include - -#include "core/framework/func_api.h" -#include "core/session/onnxruntime_cxx_api.h" - -#include "tvm_common.h" -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -class RunnerImpl { - public: - RunnerImpl() = delete; - RunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector tensors_outputs, - bool set_output_zero_copy); - virtual ~RunnerImpl() = default; - - virtual common::Status run(Ort::KernelContext& context) { - common::Status res; - if (set_output_zero_copy_) { - res = run_without_output_copying(context); - } else { - res = run_with_output_copying(context); - } - return res; - } - - virtual common::Status run_without_output_copying(Ort::KernelContext& context) { - set_input(context); - connect_output_tensors2ort(context); - set_output_zero_copy(); - run(); - - return Status::OK(); - } - - virtual common::Status run_with_output_copying(Ort::KernelContext& context) { - set_input(context); - connect_output_tensors2ort(context); - run(); - get_outputs(); - - return Status::OK(); - } - - virtual void set_input(Ort::KernelContext& context) = 0; - virtual void connect_output_tensors2ort(Ort::KernelContext& context) = 0; - virtual void set_output_zero_copy() = 0; - virtual void run() = 0; - virtual void get_outputs() = 0; - - protected: - void convert_input_tensors2dl_tensors(Ort::KernelContext& context, - std::vector& dst, - std::vector& dst_inds); - void add_device_type_data2output_tensors(Ort::KernelContext& context); - - protected: - std::shared_ptr mod_; - InputsInfoMap inputs_info_; - TVMTensorShapes output_shapes_; - std::vector output_tensors_; - bool set_output_zero_copy_; -}; - -class GERunnerImpl : public RunnerImpl { - public: - GERunnerImpl() = delete; - GERunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector tensors_outputs, - bool set_output_zero_copy); - virtual ~GERunnerImpl() = default; - - void set_input(Ort::KernelContext& context) final; - void connect_output_tensors2ort(Ort::KernelContext& context) final; - void set_output_zero_copy() final; - void run() final; - void get_outputs() final; -}; - -class VMRunnerImpl : public RunnerImpl { - public: - VMRunnerImpl() = delete; - VMRunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector tensors_outputs, - bool set_output_zero_copy); - virtual ~VMRunnerImpl() = default; - - void set_input(Ort::KernelContext& context) final; - void connect_output_tensors2ort(Ort::KernelContext& context) final; - void set_output_zero_copy() final; - void run() final; - void get_outputs() final; - - private: - void infer_once_to_get_output_shapes(); - - private: - bool probe_infer_ = false; -}; - -std::shared_ptr getTVMRunnerImpl(const std::shared_ptr& mod, - const TvmEPOptions& options, - const InputsInfoMap& inputs_info, - const std::vector output_tensors); - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_TVM_RUNNER_IMPL_H diff --git a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc deleted file mode 100644 index 029f25d6f292a..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// Licensed under the MIT License. - -#include -#include -#include - -#include "core/framework/execution_provider.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/kernel_registry.h" -#include "core/framework/compute_capability.h" -#include "core/platform/env.h" -#include "core/graph/model.h" - -#include "tvm_so_execution_provider.h" // NOLINT(build/include_subdir) -#include "xpu_data_transfer.h" // NOLINT(build/include_subdir) -#include "tvm_allocator.h" // NOLINT(build/include_subdir) -#include "tvm_utils.h" // NOLINT(build/include_subdir) -#include "tvm_api.h" // NOLINT(build/include_subdir) -#ifdef USE_TVM_HASH -#include "hash_alg/hasher.h" // NOLINT(build/include_subdir) -#endif - -using ONNX_NAMESPACE::TensorShapeProto; - -namespace onnxruntime { -namespace tvm { - -// Information to construct kernel function state. -struct TVMFuncState { - AllocateFunc allocate_func = nullptr; - DestroyFunc release_func = nullptr; - AllocatorHandle allocator = nullptr; - std::shared_ptr compiler = nullptr; -}; - -TvmSoExecutionProvider::TvmSoExecutionProvider(const TvmEPOptions& options) - : IExecutionProvider{kTvmExecutionProvider}, - options_{options} { - // Get environment variables - const Env& env_instance = Env::Default(); - - const std::string dump_subgraphs_env = env_instance.GetEnvironmentVar(env_vars::kDumpSubgraphs); - ORT_ENFORCE(dump_subgraphs_env.empty(), "TVM EP processing shared lib does not support subgraphs"); -} - -std::vector TvmSoExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo default_memory_info = {[](int) { - return std::make_unique(); - }, - 0, false}; - return std::vector{CreateAllocator(default_memory_info)}; -} - -TvmSoExecutionProvider::~TvmSoExecutionProvider() {} - -std::vector> -TvmSoExecutionProvider::GetCapability(const GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { - std::vector> result; - if (graph_viewer.IsSubgraph()) { - return result; - } - - const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); - - std::unordered_set required_initializers; - const std::vector& sorted_nodes = graph_viewer.GetNodesInTopologicalOrder(); - std::unique_ptr sub_graph = std::make_unique(); - for (auto& node_idx : sorted_nodes) { - graph_viewer.GetNode(node_idx)->ForEachDef([&required_initializers, &init_tensors](const NodeArg& node_arg, bool is_input) { - if (is_input && init_tensors.count(node_arg.Name())) { - required_initializers.insert(node_arg.Name()); - } }, true); - } - - auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "TVMStandalone"; - meta_def->domain = "StandaloneTest"; - std::vector inputs; - std::vector outputs; - - for (auto& nodeArgPtr : graph_viewer.GetInputs()) { - inputs.push_back(nodeArgPtr->Name()); - } - - for (auto& name : required_initializers) { - inputs.push_back(name); - } - - for (auto& nodeArgPtr : graph_viewer.GetOutputs()) { - outputs.push_back(nodeArgPtr->Name()); - } - meta_def->inputs = inputs; - meta_def->outputs = outputs; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - sub_graph->SetMetaDef(std::move(meta_def)); - sub_graph->nodes = sorted_nodes; - result.push_back( - std::make_unique(std::move(sub_graph))); - return result; -} - -common::Status TvmSoExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - printOptions(); - for (auto& fused_node_graph : fused_nodes_and_graphs) { - const GraphViewer& graph_body_viewer = 
fused_node_graph.filtered_graph; - const Node& fused_node = fused_node_graph.fused_node; -#ifdef USE_TVM_HASH - if (options_.check_hash) { - ORT_ENFORCE(checkHash(ToUTF8String(fused_node.ModelPath().ToPathString())), - "Hash check shows that used tuning files were not obtained for the given onnx-model"); - } -#endif - const std::string func_name = fused_node.Name(); - - compilers_[func_name] = std::make_shared(); - InputsInfoMap all_input_shapes; - auto mod = compileModel(func_name, graph_body_viewer, all_input_shapes); - - std::vector output_tensors(graph_body_viewer.GetOutputs().size()); - prepareOutputTensors(output_tensors); - - runners_[func_name] = std::make_shared(options_, mod, all_input_shapes, output_tensors); - - // TODO(vvchernov): implement ops checking and mechanism of gracefully passing the responsibility to other EPs - // if the checking fails due to unsupported op(s) - NodeComputeInfo compute_info = prepareComputeInfo(func_name); - - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} - -std::unique_ptr TvmSoExecutionProvider::GetDataTransfer() const { - // TODO(vvchernov): target or target host? - if (TvmEPOptionsHelper::checkGPUTarget(options_.target)) { - return std::make_unique(); - } else if (TvmEPOptionsHelper::checkCPUTarget(options_.target)) { - return std::make_unique(); - } else { - ORT_NOT_IMPLEMENTED("TVM GetDataTransfer is not implemented for target ", options_.target); - } -} - -void TvmSoExecutionProvider::printOptions() { - LOGS(*GetLogger(), INFO) << options_; -} - -#ifdef USE_TVM_HASH -bool TvmSoExecutionProvider::checkHash(const std::string& onnx_path) const { - auto hasher = Hasher("sha256"); - std::string onnx_str = readFromFile(onnx_path); - std::string onnx_hash = hasher.hash(onnx_str.c_str(), onnx_str.size()); - onnx_str.clear(); - std::string hash; - if (options_.hash_file_path.empty()) { - // TODO(vvchernov): align hash file name with OctoML team - hash = readFromFile(options_.so_folder + "/hash.txt"); - } else { - hash = readFromFile(options_.hash_file_path); - } - return onnx_hash == hash; -} -#endif - -std::shared_ptr TvmSoExecutionProvider::compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& all_input_shapes) { - all_input_shapes.clear(); - - TVMTensorShapes input_shapes; - if (options_.freeze_weights) { - setInputShapesForFreezedNN(graph_viewer, input_shapes, all_input_shapes); - } else { - setInputShapesForUnfreezedNN(graph_viewer, input_shapes, all_input_shapes); - } - - std::shared_ptr mod = compilers_[func_name]->operator()(options_, input_shapes); - - return mod; -} - -void TvmSoExecutionProvider::setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - if (!graph_viewer.IsInitializedTensor(node->Name())) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - input_shapes.emplace_back(shape); - } - } -} - -void TvmSoExecutionProvider::setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - if 
(!graph_viewer.IsInitializedTensor(node->Name())) { - input_shapes.emplace_back(shape); - } - } -} - -TensorShapeVector TvmSoExecutionProvider::getInputShape(const NodeArg* node) { - TensorShapeVector shape; - const auto& node_name = node->Name(); - if (!options_.input_shapes.empty() && - options_.input_shapes.count(node_name)) { - shape = options_.input_shapes[node_name]; - } else { - shape = convertTensorShape(*node->Shape()); - } - - return shape; -} - -TensorShapeVector TvmSoExecutionProvider::convertTensorShape(const TensorShapeProto& shape_proto) { - TensorShape ort_shape = utils::GetTensorShapeFromTensorShapeProto(shape_proto); - size_t dims = ort_shape.NumDimensions(); - - TensorShapeVector shape(dims); - for (size_t j = 0; j < dims; ++j) { - int64_t dim = int64_t(ort_shape[j]); - ORT_ENFORCE(dim > 0, "Input dimension is not positive value (dim = " + std::to_string(dim) + "). " + - "Please use provider options to setup input_names and input_shapes"); - shape[j] = dim; - } - - return shape; -} - -void TvmSoExecutionProvider::prepareOutputTensors(std::vector& output_tensors) { - for (DLTensor& t : output_tensors) { - // Draft for tensor, correct data is defined during inference - t.strides = nullptr; - t.byte_offset = 0; - t.data = nullptr; - t.ndim = 0; - t.shape = nullptr; - } -} - -NodeComputeInfo TvmSoExecutionProvider::prepareComputeInfo(const std::string& func_name) { - NodeComputeInfo compute_info; - compute_info.create_state_func = std::bind(&TvmSoExecutionProvider::createStateFunc, - this, - std::placeholders::_1, - std::placeholders::_2); - - compute_info.release_state_func = [](FunctionState state) { - if (state) - delete static_cast(state); - }; - - compute_info.compute_func = *runners_[func_name].get(); - - return compute_info; -} - -int TvmSoExecutionProvider::createStateFunc(ComputeContext* context, FunctionState* state) { - auto* state_ptr = new TVMFuncState(); - *state_ptr = {context->allocate_func, - context->release_func, - context->allocator_handle, - compilers_[context->node_name]}; - // TODO(vvchernov): Who and when release state? - *state = state_ptr; - return 0; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h deleted file mode 100644 index e155aca6e01f0..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
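The Compile() implementation above follows the usual compiling-EP contract: for each fused node it compiles a TVM module, prepares placeholder output DLTensors, and hands ORT a NodeComputeInfo whose create/compute/release callbacks own the per-node state. A minimal sketch of that callback triple, using illustrative types rather than ORT's real signatures:

#include <functional>
#include <iostream>
#include <string>

// Illustrative stand-ins for NodeComputeInfo and FunctionState.
struct FuncState {
  std::string node_name;  // the real EP also stashes allocator handles and the compiler here
};

struct ComputeInfo {
  std::function<int(const std::string&, void**)> create_state;
  std::function<int(void*)> compute;
  std::function<void(void*)> release_state;
};

int main() {
  ComputeInfo info;
  info.create_state = [](const std::string& node, void** state) {
    *state = new FuncState{node};
    return 0;
  };
  info.compute = [](void* state) {
    auto* s = static_cast<FuncState*>(state);
    std::cout << "running compiled function for " << s->node_name << "\n";
    return 0;
  };
  info.release_state = [](void* state) { delete static_cast<FuncState*>(state); };

  void* state = nullptr;
  info.create_state("TVMStandalone", &state);
  info.compute(state);
  info.release_state(state);
  return 0;
}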
- -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_SO_EXECUTION_PROVIDER_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_SO_EXECUTION_PROVIDER_H_ - -#include -#include -#include -#include - -#include "core/common/logging/logging.h" -#include "core/framework/execution_provider.h" -#include "core/platform/ort_mutex.h" - -#include "tvm_compiler.h" // NOLINT(build/include_subdir) -#include "tvm_runner.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -class Graph; -class NodeArg; -namespace tvm { - -class TvmSoExecutionProvider : public IExecutionProvider { - using Compiler = TVMCompilerBase; - using Compilers = std::unordered_map>; - using Runner = TVMRunner; - using Runners = std::unordered_map>; - - public: - explicit TvmSoExecutionProvider(const TvmEPOptions& options); - virtual ~TvmSoExecutionProvider(); - - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; - - common::Status Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; - std::unique_ptr GetDataTransfer() const override; - std::vector CreatePreferredAllocators() override; - - private: - void printOptions(); -#ifdef USE_TVM_HASH - bool checkHash(const std::string& onnx_path) const; -#endif - std::shared_ptr compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& inputs_info); // NOLINT - void setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - void setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - TensorShapeVector getInputShape(const NodeArg* node); - TensorShapeVector convertTensorShape(const ONNX_NAMESPACE::TensorShapeProto& shape_proto); - void prepareOutputTensors(std::vector& output_tensors); // NOLINT - NodeComputeInfo prepareComputeInfo(const std::string& func_name); - int createStateFunc(ComputeContext*, FunctionState*); - - private: - TvmEPOptions options_; - Compilers compilers_; - Runners runners_; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_SO_EXECUTION_PROVIDER_H_ diff --git a/onnxruntime/core/providers/tvm/tvm_utils.cc b/onnxruntime/core/providers/tvm/tvm_utils.cc deleted file mode 100644 index e0a5b566835c8..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_utils.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_UTILS_H -#define TVM_UTILS_H - -#include -#include - -#include "tvm_utils.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -namespace tvm { - -std::string readFromFile(const std::string& file_path) { - std::string str; - - std::ifstream t(file_path); - t.seekg(0, std::ios::end); - str.reserve(t.tellg()); - t.seekg(0, std::ios::beg); - - str.assign((std::istreambuf_iterator(t)), - std::istreambuf_iterator()); - return str; -} - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_UTILS_H diff --git a/onnxruntime/core/providers/tvm/tvm_utils.h b/onnxruntime/core/providers/tvm/tvm_utils.h deleted file mode 100644 index de77368c715b9..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_utils.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifndef TVM_UTILS_H -#define TVM_UTILS_H - -#include - -#include "tvm_common.h" - -#include "core/session/onnxruntime_cxx_api.h" -#include "core/framework/ortdevice.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm { - -inline DLDataType GetDataType(ONNXTensorElementDataType type) { - switch (type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return {kDLUInt, 8, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - return {kDLInt, 8, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - return {kDLUInt, 16, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - return {kDLInt, 16, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - return {kDLUInt, 32, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return {kDLInt, 32, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - return {kDLUInt, 64, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - return {kDLInt, 64, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - return {kDLFloat, 16, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return {kDLFloat, 32, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - return {kDLFloat, 64, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - return {kDLUInt, 1, 1}; - default: - ORT_NOT_IMPLEMENTED("Unsupported data type"); - } -} - -inline DLDevice GetDLDevice(OrtMemoryInfoDeviceType device_type) { - DLDevice context; - switch (device_type) { - case OrtDevice::CPU: - context = {kDLCPU, 0}; - break; - case OrtDevice::GPU: - context = {kDLVulkan, 0}; - break; - default: - ORT_NOT_IMPLEMENTED("Unsupported device"); - break; - } - return context; -} - -std::string readFromFile(const std::string& file_path); - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_UTILS_H diff --git a/onnxruntime/core/providers/tvm/xpu_data_transfer.cc b/onnxruntime/core/providers/tvm/xpu_data_transfer.cc deleted file mode 100644 index 4011dee7b7b7f..0000000000000 --- a/onnxruntime/core/providers/tvm/xpu_data_transfer.cc +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/tensor.h" - -#include "xpu_data_transfer.h" -#include "tvm_utils.h" - -namespace onnxruntime { -namespace tvm { - -XPUDataTransfer::XPUDataTransfer() { -} - -XPUDataTransfer::~XPUDataTransfer() { -} - -bool XPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return (src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU) || - (src_device.Type() == OrtDevice::GPU || dst_device.Type() == OrtDevice::GPU); -} - -common::Status XPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - size_t bytes = src.SizeInBytes(); - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - const auto src_device_type = src.Location().device.Type(); - const auto dst_device_type = dst.Location().device.Type(); - - if ((src_device_type == OrtDevice::CPU) && (dst_device_type == OrtDevice::CPU)) { - if (src_data == dst_data) { - // no need copying as both pointers are referring to same piece of memory. 
- return Status::OK(); - } - memcpy(dst_data, src_data, bytes); - } else { - DLTensor tvm_src, tvm_dst; - DLDataType dl_type{kDLInt, 8, 1}; - std::vector shape{int64_t(bytes)}; - // Construct source DLTensor - tvm_src.device = GetDLDevice(static_cast(src_device_type)); - tvm_src.dtype = dl_type; - tvm_src.strides = nullptr; - tvm_src.byte_offset = 0; - tvm_src.data = const_cast(src_data); - tvm_src.ndim = 1; - tvm_src.shape = shape.data(); - // Construct destination DLTensor - tvm_dst.device = GetDLDevice(static_cast(dst_device_type)); - tvm_dst.dtype = dl_type; - tvm_dst.strides = nullptr; - tvm_dst.byte_offset = 0; - tvm_dst.data = dst_data; - tvm_dst.ndim = 1; - tvm_dst.shape = shape.data(); - // Copying from src to dst - TVMDeviceCopyDataFromTo(&tvm_src, &tvm_dst, nullptr); - } - return Status::OK(); -} - -DLDevice XPUDataTransfer::get_context(const OrtDevice& device) const { - return GetDLDevice(static_cast(device.Type())); -} - -bool TvmCPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU; -} - -common::Status TvmCPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - if (src_data == dst_data) { - // no need copying as both pointers are referring to same piece of memory. - return Status::OK(); - } - // Copying only happens between two same size tensors. - ORT_ENFORCE(src.SizeInBytes() == dst.SizeInBytes()); - memcpy(dst_data, src_data, src.SizeInBytes()); - return Status::OK(); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/xpu_data_transfer.h b/onnxruntime/core/providers/tvm/xpu_data_transfer.h deleted file mode 100644 index a2cf55b241bb1..0000000000000 --- a/onnxruntime/core/providers/tvm/xpu_data_transfer.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
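The deleted XPUDataTransfer::CopyTensor handled the non-CPU path by describing both raw buffers to TVM as rank-1 int8 DLTensors of `bytes` elements and letting TVMDeviceCopyDataFromTo perform the device copy. A minimal sketch of that byte-view wrapping, assuming the standard DLPack header is available on the include path; the helper name is illustrative:

#include <cstdint>
#include <vector>

#include <dlpack/dlpack.h>  // assumption: the DLPack headers shipped with TVM are available

// Any contiguous buffer can be described as a rank-1 int8 DLTensor of `num_bytes` elements.
DLTensor WrapAsByteTensor(void* data, int64_t* num_bytes, DLDeviceType device_type) {
  DLTensor t;
  t.data = data;
  t.device = {device_type, 0};
  t.ndim = 1;
  t.dtype = {kDLInt, 8, 1};
  t.shape = num_bytes;  // the shape storage must outlive the DLTensor
  t.strides = nullptr;  // contiguous layout
  t.byte_offset = 0;
  return t;
}

int main() {
  std::vector<float> host(16, 1.0f);
  int64_t bytes = static_cast<int64_t>(host.size() * sizeof(float));
  DLTensor view = WrapAsByteTensor(host.data(), &bytes, kDLCPU);
  // A real caller would build a second view for the destination and pass both
  // to TVMDeviceCopyDataFromTo(&src, &dst, nullptr).
  return view.ndim == 1 ? 0 : 1;
}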
- -#ifndef XPU_DATA_TRANSFER -#define XPU_DATA_TRANSFER - -#include "core/framework/data_transfer.h" -#include "tvm_common.h" - -namespace onnxruntime { -namespace tvm { - -class XPUDataTransfer : public IDataTransfer { - public: - XPUDataTransfer(); - ~XPUDataTransfer(); - - bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; - - // Dumpen MSVC warning about not fully overriding - using IDataTransfer::CopyTensor; - common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; - DLDevice get_context(const OrtDevice& device) const; -}; - -class TvmCPUDataTransfer : public IDataTransfer { - public: - TvmCPUDataTransfer() = default; - // Dampen MSVC warning about not fully overriding CopyTensor - using IDataTransfer::CopyTensor; - bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; - common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // XPU_DATA_TRANSFER diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc index a9275b24ce91f..2b9ddf8ad147f 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc @@ -104,4 +104,8 @@ std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeP } return ret; } +std::string* attr_proto_release_string(ONNX_NAMESPACE::AttributeProto* attr) { + vai_assert(attr->type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRING, attr->name()); + return attr->release_s(); +} } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h index bb2883512037b..08d980ec94c14 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.h @@ -23,5 +23,6 @@ const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::A gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr); gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr); std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr); +std::string* attr_proto_release_string(ONNX_NAMESPACE::AttributeProto* attr); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 41885721e7b9a..cccaa65de45f2 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -7,7 +7,9 @@ #include #include #include - +#ifdef _WIN32 +#include +#endif #include "./vai_assert.h" #include "core/common/exceptions.h" @@ -52,6 +54,13 @@ struct OrtVitisAIEpAPI { int (*vitisai_ep_on_run_start)( const std::vector>& eps, const void* state, vaip_core::DllSafe (*get_config_entry)(const void* state, const char* entry_name)) = nullptr; + int (*vitisai_ep_set_ep_dynamic_options)( + const std::vector>& eps, + const char* const* keys, + const char* const* values, size_t kv_len) = nullptr; + void (*profiler_collect)( + std::vector& api_events, + std::vector& kernel_events); void Ensure() { if (handle_) return; @@ -75,8 +84,10 @@ struct OrtVitisAIEpAPI { } std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version", (void**)&vaip_get_version); + std::ignore = env.GetSymbolFromLibrary(handle_, "profiler_collect", (void**)&profiler_collect); 
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes)); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options)); } private: @@ -90,6 +101,14 @@ static vaip_core::OrtApiForVaip the_global_api; std::shared_ptr get_kernel_registry_vitisaiep() { return s_kernel_registry_vitisaiep; } const std::vector& get_domains_vitisaiep() { return s_domains_vitisaiep; } +void profiler_collect( + std::vector& api_events, + std::vector& kernel_events) { + if (s_library_vitisaiep.profiler_collect) { + s_library_vitisaiep.profiler_collect(api_events, kernel_events); + } +} + vaip_core::DllSafe>> compile_onnx_model( const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { auto model_path = graph_viewer.ModelPath().string(); @@ -118,6 +137,15 @@ int vitisai_ep_on_run_start( return 100; } +int vitisai_ep_set_ep_dynamic_options( + const std::vector>& eps, const char* const* keys, + const char* const* values, size_t kv_len) { + if (s_library_vitisaiep.vitisai_ep_set_ep_dynamic_options) { + return s_library_vitisaiep.vitisai_ep_set_ep_dynamic_options(eps, keys, values, kv_len); + } + return 100; +} + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = @@ -428,6 +456,18 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } }; the_global_api.node_arg_external_location = vaip::node_arg_external_location; + the_global_api.model_to_proto = [](onnxruntime::Model& model) { return model.ToProto().release(); }; + the_global_api.model_proto_serialize_as_string = [](ONNX_NAMESPACE::ModelProto& model_proto) { + return vaip_core::DllSafe(model_proto.SerializeAsString()); + }; + the_global_api.model_proto_delete = [](ONNX_NAMESPACE::ModelProto* p) { delete p; }; + the_global_api.attr_proto_release_string = [](ONNX_NAMESPACE::AttributeProto* attr) -> vaip_core::DllSafe { + auto pstr = vaip::attr_proto_release_string(attr); + std::string local_str = std::move(*pstr); + pstr = nullptr; + return vaip_core::DllSafe(std::move(local_str)); + }; + if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h index 5d020e00ff5b7..64cf52ec0a404 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h @@ -25,18 +25,18 @@ class ExecutionProvider { virtual DllSafe> get_meta_def_nodes() const = 0; virtual DllSafe> get_meta_def_constant_initializer() const = 0; + virtual bool get_meta_def_fallback_CPU() const { return false; }; virtual std::unique_ptr compile() const = 0; public: - inline void set_fused_node(const onnxruntime::Node* fused_node) { - fused_node_ = fused_node; - } - inline const onnxruntime::Node* get_fused_node() const { - return fused_node_; - } + inline void set_fused_node(const onnxruntime::Node* fused_node) { fused_node_ = fused_node; } + inline const onnxruntime::Node* get_fused_node() const { return fused_node_; } + inline void set_model(onnxruntime::Model* model) { model_ = model; } + inline onnxruntime::Model* get_model() const { return model_; } private: const 
onnxruntime::Node* fused_node_ = nullptr; + onnxruntime::Model* model_ = nullptr; }; class CustomOp { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 1a90f4c7fdebb..704b156dff57f 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -20,3 +20,22 @@ std::optional> create_ep_context_nodes( int vitisai_ep_on_run_start( const std::vector>& eps, const void* state, vaip_core::DllSafe (*get_config_entry)(const void* state, const char* entry_name)); +int vitisai_ep_set_ep_dynamic_options( + const std::vector>& eps, + const char* const* keys, + const char* const* values, size_t kv_len); +/** + * Replace EventRecord with std::tuple, + * because EventRecord is defined in profiler_common.h which is used inside onnxruntime. + * However, profiler_collect function will call vitis ep which can't include profiler_common.h. + */ +using EventInfo = std::tuple< + std::string, // name + int, // pid + int, // tid + long long, // timestamp + long long // duration + >; +void profiler_collect( + std::vector& api_events, + std::vector& kernel_events); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 74482d8e9ee0e..7628e45d2b933 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -20,6 +20,7 @@ struct NodeAttributes; namespace ONNX_NAMESPACE { struct AttributeProto; struct TensorProto; +struct ModelProto; #ifndef USE_VITISAI enum TensorProto_DataType : int { TensorProto_DataType_UNDEFINED = 0, @@ -70,6 +71,7 @@ enum AttributeProto_AttributeType : int { namespace vaip_core { class GraphHolder; using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::ModelProto; using ONNX_NAMESPACE::TensorProto; using onnxruntime::Graph; using onnxruntime::GraphViewer; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index bbe8b6e6e4934..9425c08dceebc 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (10u) +#define VAIP_ORT_API_MAJOR (12u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -231,6 +231,10 @@ struct OrtApiForVaip { gsl::span inputs); // [92] int (*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); // [93] void (*session_option_configuration)(void* mmap, void* session_option, void (*push)(void* mmap, const char* name, const char* value)); // [94] + ModelProto* (*model_to_proto)(Model& model); // [95] + DllSafe (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96] + void (*model_proto_delete)(ModelProto* p); // [97] + DllSafe (*attr_proto_release_string)(AttributeProto* attr); // [98] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 09b115b4a57fc..3a99f56bb732a 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -1,6 +1,7 @@ // Copyright (c) 
2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #include "vitisai_execution_provider.h" +#include "vitisai_profiler.h" // Standard headers/libs. #include @@ -76,7 +77,17 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectorexecution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get()); + auto& ep = (**this->execution_providers_)[index]; + ep->set_fused_node(&fused_node_graph.fused_node.get()); + if (ep->get_meta_def_fallback_CPU()) { + auto& subgraph = fused_node_graph.filtered_graph.get(); + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model_proto = subgraph.CreateModel(logger)->ToProto(); + subgraph.ToProto(*model_proto->mutable_graph(), true, true); + auto local_registries = IOnnxRuntimeOpSchemaRegistryList{subgraph.GetSchemaRegistry()}; + auto model = Model::Create(std::move(*model_proto), subgraph.ModelPath(), &local_registries, logger); + ep->set_model(model.release()); + } compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; @@ -110,9 +121,23 @@ common::Status VitisAIExecutionProvider::OnRunStart(const onnxruntime::RunOption }; auto error_code = vitisai_ep_on_run_start(**execution_providers_, (const void*)&run_options, get_config_entry); if (error_code) { - return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, std::to_string(error_code)); + std::string error_msg = "vitisai_ep_on_run_start ret: " + std::to_string(error_code); + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, error_msg); } return Status::OK(); } +common::Status VitisAIExecutionProvider::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + auto error_code = vitisai_ep_set_ep_dynamic_options(**execution_providers_, keys.data(), values.data(), std::min(keys.size(), values.size())); + if (error_code) { + std::string error_msg = "vitisai_ep_set_ep_dynamic_options ret: " + std::to_string(error_code); + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, error_msg); + } + return Status::OK(); +} + +std::unique_ptr VitisAIExecutionProvider::GetProfiler() { + return std::make_unique(); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 05d2a976815b9..f0d1a289a2a73 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -36,9 +36,13 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetProfiler() override; + // This method is called after both `GetComputeCapabilityOps()` and `Compile()`. // This timing is required to work with both compliation-based EPs and non-compilation-based EPs. const InlinedVector GetEpContextNodes() const override; + virtual common::Status SetEpDynamicOptions(gsl::span /*keys*/, + gsl::span /*values*/) override; private: using my_ep_t = vaip_core::DllSafe>>; @@ -48,10 +52,9 @@ class VitisAIExecutionProvider : public IExecutionProvider { ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; - std::set vitisai_optypes_; // EP context related. 
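The new SetEpDynamicOptions override above forwards two parallel C-style string arrays to vitisai_ep_set_ep_dynamic_options, clamping the count to the shorter of the two spans and surfacing a non-zero return as a FAIL status. A minimal sketch of how a caller-side layer might marshal std::string options into that shape; the option key and value are placeholders, and the public session API that ultimately reaches this hook is not part of this hunk:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Illustrative forwarding target with the same shape as vitisai_ep_set_ep_dynamic_options.
int SetDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len) {
  for (size_t i = 0; i < kv_len; ++i) {
    std::cout << keys[i] << " = " << values[i] << "\n";
  }
  return 0;  // the real hook returns a non-zero code (e.g. 100) when the symbol is unavailable
}

int main() {
  // Placeholder key/value pair, not a documented VitisAI option.
  std::vector<std::string> keys = {"example.dynamic.option"};
  std::vector<std::string> values = {"enabled"};

  // Build the parallel const char* views expected by the C-style entry point.
  std::vector<const char*> key_ptrs, value_ptrs;
  for (const auto& k : keys) key_ptrs.push_back(k.c_str());
  for (const auto& v : values) value_ptrs.push_back(v.c_str());

  return SetDynamicOptions(key_ptrs.data(), value_ptrs.data(),
                           std::min(key_ptrs.size(), value_ptrs.size()));
}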
   bool ep_ctx_enabled_ = false;
-  bool ep_ctx_embed_mode_ = true;
+  bool ep_ctx_embed_mode_ = false;
   std::string ep_ctx_model_path_cfg_{""};
   mutable PathString ep_ctx_model_file_loc_{};
   // It might need to be called before loading
diff --git a/onnxruntime/core/providers/vitisai/vitisai_profiler.cc b/onnxruntime/core/providers/vitisai/vitisai_profiler.cc
new file mode 100644
index 0000000000000..d84507ec6ad02
--- /dev/null
+++ b/onnxruntime/core/providers/vitisai/vitisai_profiler.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
+// Licensed under the MIT License.
+
+#include "vitisai_profiler.h"
+
+namespace onnxruntime {
+namespace profiling {
+
+#if defined(USE_VITISAI)
+
+bool VitisaiProfiler::StartProfiling(TimePoint tp) {
+  return true;
+}
+
+void VitisaiProfiler::EndProfiling(TimePoint tp, Events& events) {
+  auto time_point =
+      std::chrono::duration_cast<std::chrono::microseconds>(tp.time_since_epoch()).count();
+
+  std::vector<EventInfo> api_events;
+  std::vector<EventInfo> kernel_events;
+  profiler_collect(api_events, kernel_events);
+
+  std::unordered_map<std::string, std::string> event_args;
+
+  for (auto& a : api_events) {
+    events.emplace_back(EventCategory::API_EVENT,
+                        std::get<1>(a),               // pid
+                        std::get<2>(a),               // tid
+                        std::get<0>(a),               // name
+                        std::get<3>(a) - time_point,  // timestamp
+                        std::get<4>(a),               // duration
+                        event_args);
+  }
+
+  for (auto& k : kernel_events) {
+    events.emplace_back(EventCategory::KERNEL_EVENT,
+                        std::get<1>(k),
+                        std::get<2>(k),
+                        std::get<0>(k),
+                        std::get<3>(k) - time_point,
+                        std::get<4>(k),
+                        event_args);
+  }
+}
+
+#endif
+
+}  // namespace profiling
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/vitisai/vitisai_profiler.h b/onnxruntime/core/providers/vitisai/vitisai_profiler.h
new file mode 100644
index 0000000000000..aedbda31f7b1d
--- /dev/null
+++ b/onnxruntime/core/providers/vitisai/vitisai_profiler.h
@@ -0,0 +1,23 @@
+// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
+// Licensed under the MIT License.
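vitisai_profiler.cc above turns the EventInfo tuples collected from the VitisAI library into ORT profiling events by position (name, pid, tid, timestamp, duration). A small standalone sketch of building and unpacking such a tuple; the microsecond unit is an assumption consistent with the duration_cast in EndProfiling, not something the header spells out:

#include <chrono>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

// Mirrors the EventInfo alias declared in vaip/global_api.h:
// (name, pid, tid, timestamp, duration).
using EventInfo = std::tuple<std::string, int, int, long long, long long>;

int main() {
  // Assumed microsecond timestamps; the diff only fixes the fields as long long.
  auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                    std::chrono::steady_clock::now().time_since_epoch())
                    .count();

  std::vector<EventInfo> kernel_events;
  kernel_events.emplace_back("example_kernel", /*pid*/ 1234, /*tid*/ 5678,
                             /*timestamp*/ now_us, /*duration*/ 42);

  for (const auto& e : kernel_events) {
    std::cout << std::get<0>(e) << " took " << std::get<4>(e) << " us\n";
  }
  return 0;
}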
+ +#include "core/providers/vitisai/include/vaip/global_api.h" + +namespace onnxruntime { +namespace profiling { + +#if defined(USE_VITISAI) +class VitisaiProfiler final : public EpProfiler { + public: + VitisaiProfiler() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(VitisaiProfiler); + ~VitisaiProfiler() {} + bool StartProfiling(TimePoint) override; + void EndProfiling(TimePoint, Events&) override; + void Start(uint64_t) override{}; + void Stop(uint64_t) override{}; +}; +#endif + +} // namespace profiling +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h index 3ed432c2efa1c..5278efdb4a400 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h @@ -112,7 +112,7 @@ class ConvOpBuilder : public BaseOpBuilder { } } } else { - auto pads = helper.Get("pads", std::vector{0U, 0U}); + auto pads = helper.Get("pads", std::vector{0U, 0U, 0U, 0U}); if (group != 1 && group != weight_tensor->GetShape()[OChannel_idx]) { if (is_1d_conv) { op = graph_ep->GetGraph() diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h index 4c10ba01b1c2e..7da1e6e674601 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h @@ -65,6 +65,12 @@ ELEMENTWISE_OP_BUILDER(Floor, Floor); ELEMENTWISE_OP_BUILDER(Log, Log); ELEMENTWISE_OP_BUILDER(Sin, Sin); ELEMENTWISE_OP_BUILDER(HardSwish, HardSwish); +ELEMENTWISE_OP_BUILDER(Neg, Neg); +ELEMENTWISE_OP_BUILDER(Not, LogicalNot); +ELEMENTWISE_OP_BUILDER(Ceil, Ceil); +ELEMENTWISE_OP_BUILDER(Round, Round); +ELEMENTWISE_OP_BUILDER(Min, Minimum); +ELEMENTWISE_OP_BUILDER(Max, Maximum); class PowOpBuilder : public BaseOpBuilder { bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h new file mode 100644 index 0000000000000..19cbe4e7f3e48 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h @@ -0,0 +1,191 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/optimizer/initializer.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +typedef tim::vx::ops::PadV2::pad_mode_type PadMode; + +class PadOpBuilder : public BaseOpBuilder { + public: + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 11; } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + const auto mode = helper.Get("mode", "constant"); + auto input_defs = node->InputDefs(); + size_t num_inputs = input_defs.size(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + const auto& initializers = graph_viewer.GetAllInitializedTensors(); + + if (mode == "wrap") { + LOGS_DEFAULT(WARNING) << "`wrap` mode Pad is not currently supported for now."; + return false; + } + if (mode == "constant") { + if (num_inputs > 2 && input_defs[2]->Exists()) { + // only support if `constant_value` input is a constant initializer + if (!Contains(initializers, input_defs[2]->Name())) { + LOGS_DEFAULT(WARNING) << "constant_value must be a constant initializer."; + return false; + } + } + } + // only support if `pads` input is known and does not contain negative values + { + const auto* pads_initializer = graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!pads_initializer) { + LOGS_DEFAULT(WARNING) << "pads must be a constant initializer"; + return false; + } + + Initializer unpacked_tensor(*pads_initializer); + auto tensor_data = unpacked_tensor.DataAsSpan(); + for (size_t i = 0; i < unpacked_tensor.size(); i++) { + if (tensor_data[i] < 0) { + LOGS_DEFAULT(WARNING) << "Negative pad value is not supported: pads[" + << i << "] = " << tensor_data[i]; + return false; + } + } + } + return true; + } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (0 == i) { + if (!util::IsTypeSupported(&iodef.node_arg) || + (*iodef.node_arg.Type() == "tensor(int64)") || + (*iodef.node_arg.Type() == "tensor(bool)")) { + LOGS_DEFAULT(WARNING) << "Unspport tensor data type:" << *iodef.node_arg.Type(); + return false; + } + } else if (1 == i) { + if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "pads must be a constant initializer."; + return false; + } + } else if (2 == i) { + if (iodef.node_arg.Exists() && !Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "constant_value must be a constant initializer."; + return false; + } + } else if (i == 3) { + if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "axes must be a constant initializer.."; + return false; + } + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& 
outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Pad Op."; + NodeAttrHelper helper(node_unit); + const auto mode = helper.Get("mode", "constant"); + auto input_defs = node_unit.Inputs(); + PadMode pad_mode = PadMode::PAD_MODE_CONSTANT; + float const_val = 0.0f; + std::vector axes_tensor_data; + int32_t input_rank = inputs[0]->GetShape().size(); + + if (mode == "constant") { + pad_mode = PadMode::PAD_MODE_CONSTANT; + } else if (mode == "reflect") { + pad_mode = PadMode::PAD_MODE_REFLECT; + } else if (mode == "edge") { + pad_mode = PadMode::PAD_MODE_EDGE; + } else { + LOGS_DEFAULT(WARNING) << "`wrap` mode Pad is not currently supported for now."; + return false; + } + + // `pads` input + std::vector onnx_pads(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(onnx_pads.data()); + + // `constant_value` input + if (inputs.size() > 2 && pad_mode == PadMode::PAD_MODE_CONSTANT) { + if (input_defs[2].node_arg.Exists()) { + inputs[2]->CopyDataFromTensor(&const_val); + } + } + // `axes` input + if (inputs.size() > 3) { + // optional input axes is provided, use axes initializer data + std::vector axes_tensor(inputs[3]->GetSpec().GetElementNum()); + inputs[3]->CopyDataFromTensor(axes_tensor.data()); + std::transform( + axes_tensor.begin(), axes_tensor.end(), std::back_inserter(axes_tensor_data), + [input_rank](int64_t axis) { return HandleNegativeAxis(axis, input_rank); }); + } else { + // if not provided, make a default axes as [0, 1, ..., input_rank - 1] + std::vector default_axes(input_rank); + std::iota(std::begin(default_axes), std::end(default_axes), 0); + axes_tensor_data = std::move(default_axes); + } + + int64_t num_axes = axes_tensor_data.size(); + std::vector front_size(input_rank, 0); + std::vector back_size(input_rank, 0); + + int64_t axes_index = 0; + for (int64_t axes : axes_tensor_data) { + front_size[axes] = onnx_pads[axes_index]; + back_size[axes] = onnx_pads[axes_index + num_axes]; + axes_index++; + } + + std::reverse(front_size.begin(), front_size.end()); + std::reverse(back_size.begin(), back_size.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + front_size, back_size, const_val, pad_mode); + op->BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h new file mode 100644 index 0000000000000..e08416bda70d4 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h @@ -0,0 +1,190 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/optimizer/initializer.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class SplitOpBuilder : public BaseOpBuilder { + public: + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + auto axis = helper.Get("axis", 0); + auto input_defs = node->InputDefs(); + size_t num_inputs = input_defs.size(); + size_t num_outputs = node->OutputDefs().size(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + std::vector splits_list; + bool split_provided = false; + if (axis >= rank || axis < -rank) { + LOGS_DEFAULT(WARNING) << "Axis is invalid in Split. Axis(" << axis + << ") is out of rank[" << -rank << "," << rank - 1 << "]"; + return false; + } + axis = HandleNegativeAxis(axis, rank); + const auto split_dims_at_axis = input_shape.GetDims()[axis]; + if (num_inputs > 1 && input_defs[1]->Exists()) { + // if optional input `split` is provided + const auto* splits = graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!splits) { + LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; + return false; + } + Initializer unpacked_tensor(*splits); + auto split_sizes_ = unpacked_tensor.DataAsSpan(); + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + split_provided = true; + } + if (num_inputs == 1) { + // opset1,2,11 split as attribute + if (helper.HasAttr("split")) { + auto split_sizes_ = *helper.GetInt64s("split"); + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + split_provided = true; + } else if (node->SinceVersion() >= 18) { + const auto outputs_count = helper.GetInt64("num_outputs"); + if (!outputs_count.has_value()) { + LOGS_DEFAULT(WARNING) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; + return false; + } + if (outputs_count.value() != static_cast(num_outputs) || + outputs_count.value() > split_dims_at_axis) { + LOGS_DEFAULT(WARNING) << "Invalid num_outputs provided.\n. The value should be smaller or equal to the size " + "of dimension being split. 
num_outputs: " + << outputs_count.value(); + return false; + } + } + } + if (!split_provided) { + // populate split sizes based on num_outputs so existing code can be utilized + int32_t size = narrow(std::ceil(float(split_dims_at_axis) / num_outputs)); + int32_t remainder = split_dims_at_axis % size; + std::vector split_sizes_(num_outputs, size); + if (remainder) { + split_sizes_.back() = remainder; + } + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + } + + uint32_t sum_of_splits = std::accumulate(splits_list.begin(), splits_list.end(), SafeInt(0)); + if (sum_of_splits != split_dims_at_axis) { + LOGS_DEFAULT(WARNING) << "Sum of the 'split' input values must equal to the dim value at 'axis' specified. " + << "dim value at 'axis' specified: " + << split_dims_at_axis + << ", sum of 'split' input values: " + << sum_of_splits; + return false; + } + if (!std::all_of(splits_list.begin(), splits_list.end(), [](int64_t value) { return value >= 0; })) { + LOGS_DEFAULT(WARNING) << "Invalid value in 'split' attribute. All values must be > 0"; + return false; + } + auto average_split = sum_of_splits / num_outputs; + if (!std::all_of(splits_list.begin(), splits_list.end(), [average_split](int64_t value) { return value == average_split; })) { + // TO DO, remove this check after driver supports it. + LOGS_DEFAULT(WARNING) << "Uneven splits are not currently supported for now."; + return false; + } + + return true; + } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (0 == i) { + if (!util::IsTypeSupported(&iodef.node_arg) || + (*iodef.node_arg.Type() == "tensor(int64)") || + (*iodef.node_arg.Type() == "tensor(bool)")) { + LOGS_DEFAULT(WARNING) << "Unsupport tensor data type:" << *iodef.node_arg.Type(); + return false; + } + } else if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; + return false; + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Split Op."; + NodeAttrHelper helper(node_unit); + auto axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + const auto split_dims_at_axis = inputs[0]->GetShape()[axis]; + auto num_outputs = outputs.size(); + // transform splite vector to timvx slice + std::vector onnx_split; + if (inputs.size() > 1) { + std::vector split_sizes_(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(split_sizes_.data()); + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + if (inputs.size() == 1) { + if (helper.HasAttr("split")) { + auto split_sizes_ = *helper.GetInt64s("split"); + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + if (node_unit.SinceVersion() >= 18 || !helper.HasAttr("split")) { + // populate split sizes based on num_outputs so existing code can be utilized + int32_t size = narrow(std::ceil(float(split_dims_at_axis) / num_outputs)); + int32_t remainder = split_dims_at_axis % size; + std::vector split_sizes_(num_outputs, size); + if (remainder) { + split_sizes_.back() = remainder; + } + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + } + std::vector slices(onnx_split.begin(), onnx_split.end()); + 
std::reverse(slices.begin(), slices.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + axis, slices); + op->BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h index dc0969429b8ff..fcf9479a6058b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h @@ -53,6 +53,8 @@ #include "impl/cast_op_builder.h" #include "impl/dropout_op_builder.h" #include "impl/slice_op_builder.h" +#include "impl/split_op_builder.h" +#include "impl/pad_op_builder.h" namespace onnxruntime { namespace vsi { namespace npu { @@ -110,7 +112,15 @@ static const std::map reg = { REGISTER_OP_BUILDER("Resize", ResizeOpBuilder), REGISTER_OP_BUILDER("Cast", CastOpBuilder), REGISTER_OP_BUILDER("Dropout", DropoutOpBuilder), - REGISTER_OP_BUILDER("Slice", SliceOpBuilder) + REGISTER_OP_BUILDER("Slice", SliceOpBuilder), + REGISTER_OP_BUILDER("Split", SplitOpBuilder), + REGISTER_OP_BUILDER("Neg", NegOpBuilder), + REGISTER_OP_BUILDER("Not", NotOpBuilder), + REGISTER_OP_BUILDER("Ceil", CeilOpBuilder), + REGISTER_OP_BUILDER("Round", RoundOpBuilder), + REGISTER_OP_BUILDER("Min", MinOpBuilder), + REGISTER_OP_BUILDER("Max", MaxOpBuilder), + REGISTER_OP_BUILDER("Pad", PadOpBuilder) #undef REGISTER_OP_BUILDER }; diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch index 45de47f3e5128..95a4e4650e9fe 100644 --- a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -1,8 +1,8 @@ diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake -index c02ac2096d..2bc51298f0 100644 +index 10c307b3b9..a52bf71c4d 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake -@@ -361,7 +361,7 @@ else() +@@ -370,7 +370,7 @@ else() ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -12,11 +12,11 @@ index c02ac2096d..2bc51298f0 100644 ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h -index e46105324a..414c46a1ce 100644 +index 28ae64c4d5..0c77e0ca78 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h -@@ -82,6 +82,9 @@ Abstract: - +@@ -83,6 +83,9 @@ Abstract: + #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) +#if !defined(USE_VSINPU) @@ -25,51 +25,51 @@ index e46105324a..414c46a1ce 100644 #if !defined(__APPLE__) // Had to temporary disable fp16 under APPLE ARM64, as compiling // the source files require a hardware specific compilation flag. -@@ -90,6 +93,7 @@ Abstract: - +@@ -91,6 +94,7 @@ Abstract: + #define MLAS_F16VEC_INTRINSICS_SUPPORTED - + +#endif // #endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic -@@ -1635,6 +1639,7 @@ MlasHalfGemmConvertPackB( +@@ -1644,6 +1648,7 @@ MlasHalfGemmConvertPackB( ); - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) /** * @brief Whether current CPU supports Bfloat16(bf16) acceleration. 
*/ -@@ -1746,6 +1751,7 @@ MlasSBGemmPackBSize(size_t N, size_t K); +@@ -1755,6 +1760,7 @@ MlasSBGemmPackBSize(size_t N, size_t K); void MLASCALL MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); #endif +#endif - + /** * @brief Indirect Depthwise convolution for fp16 diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h -index 4239e2ecae..3df7e5573d 100644 +index 0533a5e49b..c18bf7f90d 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h -@@ -361,6 +361,7 @@ size_t +@@ -377,6 +377,7 @@ size_t #else - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( const float* A, const bfloat16_t* B, -@@ -373,6 +374,7 @@ typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( +@@ -389,6 +390,7 @@ typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( const float* Bias ); #endif +#endif - + typedef size_t -@@ -763,8 +765,10 @@ extern "C" { +@@ -796,8 +798,10 @@ extern "C" { MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; #if defined(__aarch64__) && defined(__linux__) @@ -80,39 +80,25 @@ index 4239e2ecae..3df7e5573d 100644 #endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; -@@ -899,8 +903,10 @@ extern "C" { +@@ -946,8 +950,10 @@ extern "C" { #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) #define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #endif +#endif - + // // Single-threaded single precision matrix/matrix multiply operation. -@@ -2570,4 +2576,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH - static_assert(std::is_same_v || std::is_same_v); - *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); - } -- diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp -index ed437f20f7..8c9d0a75fd 100644 +index b3c9461293..424c3b0441 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp -@@ -20,7 +20,7 @@ Abstract: - #include - #include - --#if defined(MLAS_TARGET_POWER) -+#if defined(MLAS_TARGET_POWER) - #if defined(__linux__) - #include - #elif defined(_AIX) -@@ -536,7 +536,7 @@ Return Value: - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; +@@ -574,7 +574,7 @@ Return Value: + this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } - + -#if defined(__linux__) +#if defined(__linux__) && !defined(USE_VSINPU) // @@ -124,12 +110,12 @@ index de7fd72fad..4f75dbd6fa 100644 +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -31,6 +31,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #pragma once - + @@ -396,4 +397,5 @@ MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t Bat } ); @@ -137,11 +123,11 @@ index de7fd72fad..4f75dbd6fa 100644 +#endif #endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc -index 6a71283f9d..d8bd348854 100644 +index 2c6d23e4de..61aaacdfd6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc -@@ -132,7 +132,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { - +@@ -133,7 +133,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { + return Status::OK(); } -#if defined(__aarch64__) && 
defined(__linux__) @@ -149,7 +135,7 @@ index 6a71283f9d..d8bd348854 100644 bool GemmPackBBfloat16(AllocatorPtr& alloc, const Tensor& tensor_b, bool trans_b, -@@ -180,6 +180,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc +@@ -181,6 +181,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc if (input_idx == 1) { size_t packed_b_size; #if defined(__aarch64__) && defined(__linux__) @@ -157,7 +143,7 @@ index 6a71283f9d..d8bd348854 100644 size_t dim1 = 0; size_t dim2 = 0; TensorShape b_shape = tensor.Shape(); -@@ -192,6 +193,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc +@@ -193,6 +194,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); } else @@ -165,7 +151,7 @@ index 6a71283f9d..d8bd348854 100644 #endif { is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); -@@ -257,6 +259,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { +@@ -259,6 +261,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); #if defined(__aarch64__) && defined(__linux__) @@ -173,7 +159,7 @@ index 6a71283f9d..d8bd348854 100644 if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { -@@ -273,6 +276,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { +@@ -275,6 +278,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { } MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); } else @@ -187,7 +173,7 @@ index b9bbe36583..2f570502d2 100644 +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -31,8 +31,10 @@ class MatMul final : public OpKernel { trans_batch_b_ = trans_batch_b_attr != 0; - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); @@ -195,10 +181,10 @@ index b9bbe36583..2f570502d2 100644 +#endif #endif } - + @@ -57,12 +59,14 @@ class MatMul final : public OpKernel { bool trans_batch_b_; - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) // fastmath mode state @@ -209,7 +195,7 @@ index b9bbe36583..2f570502d2 100644 #endif +#endif }; - + } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp index f85fe97776..6039b7fa9e 100644 @@ -217,12 +203,12 @@ index f85fe97776..6039b7fa9e 100644 +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -16,6 +16,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #include "test_sbgemm.h" - + @@ -138,4 +139,5 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe } return SBGemmRegistLongExecute() > 0; @@ -235,15 +221,15 @@ index 13701e2e3d..7e432f53c2 100644 +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -16,6 +16,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #pragma once - + @@ -278,4 +279,5 @@ class MlasSBGemmTest : public MlasTestBase { } }; - + +#endif #endif // defined(__aarch64__) && defined(__linux__) diff --git 
a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc index bbf8255ac2940..db8a87d9eaf24 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -34,7 +34,8 @@ namespace onnxruntime { namespace vsi { namespace npu { -GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) { +GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) + : graph_viewer_(graph_viewer), logger_(logger) { Prepare(); context_ = tim::vx::Context::Create(); graph_ = context_->CreateGraph(); @@ -42,7 +43,7 @@ GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(g } bool GraphEP::Prepare() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); for (const auto& node_unit : node_unit_holder_) { auto quant_op_type = util::GetQuantizedOpType(*node_unit); diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h index 49344770d060e..5bb332fad0177 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -51,7 +51,7 @@ struct NodeIOInfo { class GraphEP { public: - explicit GraphEP(const GraphViewer& graph_viewer); + explicit GraphEP(const GraphViewer& graph_viewer, const logging::Logger& logger); ~GraphEP() {} bool Prepare(); @@ -104,6 +104,7 @@ class GraphEP { // In the form of {input_name, [NodeUnit(s) using the input]} std::unordered_map> all_quantized_op_inputs_; const GraphViewer& graph_viewer_; + const logging::Logger& logger_; // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is // valid throughout the lifetime of the ModelBuilder diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 466fe1f82461c..7da7cc6cb63ba 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; if (graph_viewer.IsSubgraph()) { @@ -82,7 +83,7 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -174,7 +175,8 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, - OrtKernelContext* context) { + OrtKernelContext* context, + const logging::Logger& logger) { Ort::KernelContext ctx(context); size_t num_in = ctx.GetInputCount(); const size_t num_inputs = graph_ep->GetGraphInputs().size(); @@ -192,7 +194,7 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, } if 
(!graph_ep->GetGraph()->Run()) { - LOGS_DEFAULT(ERROR) << "Failed to run graph."; + LOGS(logger, ERROR) << "Failed to run graph."; } for (size_t i = 0; i < ctx.GetOutputCount(); i++) { auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor; @@ -207,12 +209,14 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, Status VSINPUExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { + const auto& logger = *GetLogger(); + for (const auto& fused_node_graph : fused_nodes_and_graphs) { const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; - std::shared_ptr graph_ep = std::make_shared(graph_viewer); + std::shared_ptr graph_ep = std::make_shared(graph_viewer, logger); for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) { - LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" + LOGS(logger, VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" << graph_viewer.IsInitializedTensor(tensor->Name()); auto input = std::make_shared(); input->name = tensor->Name(); @@ -220,7 +224,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu graph_ep->GetGraphInputs().push_back(input); } for (auto tensor : graph_viewer.GetOutputs()) { - LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); + LOGS(logger, VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); auto output = std::make_shared(); output->name = tensor->Name(); output->is_initializer = false; @@ -236,16 +240,16 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu if (node != &node_unit.GetNode()) { continue; } - LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]"; + LOGS(logger, VERBOSE) << "Adding node: [" << node->OpType() << "]"; vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit); } - LOGS_DEFAULT(INFO) << "Verifying graph"; + LOGS(logger, INFO) << "Verifying graph"; graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile(); if (!graph_ep->GetCompiled()) { - LOGS_DEFAULT(ERROR) << "Failed to verify graph."; + LOGS(logger, ERROR) << "Failed to verify graph."; } else { - LOGS_DEFAULT(INFO) << "Graph has been verified successfully."; + LOGS(logger, INFO) << "Graph has been verified successfully."; } NodeComputeInfo compute_info; @@ -258,8 +262,8 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu compute_info.compute_func = [graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */, OrtKernelContext* context) { - std::lock_guard lock(this->GetMutex()); - Status res = ComputeStateFunc(graph_ep.get(), context); + std::lock_guard lock(this->GetMutex()); + Status res = ComputeStateFunc(graph_ep.get(), context, *GetLogger()); return res; }; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h index 44318c332fdd0..c2605eb65faee 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -43,11 +43,11 @@ class VSINPUExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; - OrtMutex& GetMutex() { return mutex_; } + std::mutex& GetMutex() { return mutex_; } private: int device_id_; - OrtMutex mutex_; + std::mutex mutex_; }; } // namespace onnxruntime diff 
--git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc new file mode 100644 index 0000000000000..8e27acdc285d4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/framework/session_state.h" +#include "core/providers/webgpu/allocator.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +void* GpuBufferAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + + auto buffer = context_.BufferManager().Create(size); + + stats_.num_allocs++; + return buffer; +} + +void GpuBufferAllocator::Free(void* p) { + if (p != nullptr) { + context_.BufferManager().Release(static_cast(p)); + stats_.num_allocs--; + } +} + +void GpuBufferAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h new file mode 100644 index 0000000000000..51ca65a8b4822 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class GpuBufferAllocator : public IAllocator { + public: + GpuBufferAllocator(const WebGpuContext& context) + : IAllocator( + OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), + 0, OrtMemTypeDefault)), + context_{context} { + } + + virtual void* Alloc(size_t size) override; + virtual void Free(void* p) override; + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc new file mode 100644 index 0000000000000..45eb123943de9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -0,0 +1,361 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
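The GpuBufferAllocator above is a thin adapter over the buffer manager introduced below: Alloc() forwards to BufferManager().Create() and hands back the WGPUBuffer handle as the allocation "pointer", Free() forwards to Release(), and stats_.num_allocs tracks the number of live buffers (zero-sized allocations return nullptr and are never counted).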
+ +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +constexpr size_t NormalizeBufferSize(size_t size) { + return (size + 15) / 16 * 16; +} + +class DisabledCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + // always return empty buffer + return nullptr; + } + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + void ReleaseBuffer(WGPUBuffer buffer) override { + wgpuBufferRelease(buffer); + } + + void OnRefresh() override { + // no-op + } +}; + +class LazyReleaseCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + pending_buffers_.clear(); + } + + std::vector pending_buffers_; +}; + +class SimpleCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buffers_.find(buffer_size); + if (it != buffers_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + buffers_[wgpuBufferGetSize(buffer)].push_back(buffer); + } + pending_buffers_.clear(); + } + + std::map> buffers_; + std::vector pending_buffers_; +}; + +// TODO: maybe use different bucket size for storage and uniform buffers? +constexpr std::initializer_list> BUCKET_DEFAULT_LIMIT_TABLE = { + {64, 250}, + {128, 200}, + {256, 200}, + {512, 200}, + {2048, 230}, + {4096, 200}, + {8192, 50}, + {16384, 50}, + {32768, 50}, + {65536, 50}, + {131072, 50}, + {262144, 50}, + {524288, 50}, + {1048576, 50}, + {2097152, 30}, + {4194304, 20}, + {8388608, 10}, + {12582912, 10}, + {16777216, 10}, + {26214400, 15}, + {33554432, 22}, + {44236800, 2}, + {58982400, 6}, + // we don't want to cache the bucket sizes below but not caching them + // results in some major performance hits for models like sd-turbo. 
+ {67108864, 6}, + {134217728, 6}, + {167772160, 6}, +}; + +class BucketCacheManager : public IBufferCacheManager { + public: + BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + BucketCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + // TODO: consider graph capture. currently not supported + + for (auto& buffer : pending_buffers_) { + auto buffer_size = wgpuBufferGetSize(buffer); + + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.push_back(buffer); + } else { + wgpuBufferRelease(buffer); + } + } + + pending_buffers_.clear(); + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + buckets_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + buckets_.emplace(pair.first, std::vector()); + } + std::sort(buckets_keys_.begin(), buckets_keys_.end()); + +#ifndef NDEBUG // if debug build + ORT_ENFORCE(std::all_of(buckets_keys_.begin(), buckets_keys_.end(), [](size_t size) { return size % 16 == 0; }), + "Bucket sizes must be multiples of 16."); + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_; + std::vector pending_buffers_; + std::vector buckets_keys_; +}; + +std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { + switch (cache_mode) { + case BufferCacheMode::Disabled: + return std::make_unique(); + case BufferCacheMode::LazyRelease: + return std::make_unique(); + case BufferCacheMode::Simple: + return std::make_unique(); + case BufferCacheMode::Bucket: + return std::make_unique(); + default: + ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode"); + } +} + +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { + switch (mode) { + case BufferCacheMode::Disabled: + os << "Disabled"; + break; + case BufferCacheMode::LazyRelease: + os << "LazyRelease"; + break; + case BufferCacheMode::Simple: + os << "Simple"; + break; + case BufferCacheMode::Bucket: + os << "Bucket"; + break; + default: + os << "Unknown(" << static_cast(mode) << ")"; + } + return os; +} + +BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) + : context_{context}, + storage_cache_{CreateBufferCacheManager(storage_buffer_cache_mode)}, + uniform_cache_{CreateBufferCacheManager(uniform_buffer_cache_mode)}, + 
query_resolve_cache_{CreateBufferCacheManager(query_resolve_buffer_cache_mode)}, + default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} { +} + +void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite; + desc.mappedAtCreation = true; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto mapped_data = staging_buffer.GetMappedRange(); + memcpy(mapped_data, src, size); + staging_buffer.Unmap(); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); + pending_staging_buffers_.push_back(staging_buffer); +} + +void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { + ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); + + auto buffer_size = NormalizeBufferSize(size); + ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + "Source and destination buffers must have enough space for the copy operation. src_size=", + wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); +} + +WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { + auto& cache = GetCacheManager(static_cast(usage)); + auto buffer_size = cache.CalculateBufferSize(size); + + auto buffer = cache.TryAcquireCachedBuffer(buffer_size); + if (buffer) { + return buffer; + } + + // cache miss, create a new buffer + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = usage; + // desc.label = std::to_string(xx++).c_str(); + buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); + + ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); + + cache.RegisterBuffer(buffer, size); + return buffer; +} + +void BufferManager::Release(WGPUBuffer buffer) { + GetCacheManager(buffer).ReleaseBuffer(buffer); +} + +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); + context_.Flush(); + + // TODO: revise wait in whole project + + ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, wgpu::CallbackMode::WaitAnyOnly, [](wgpu::MapAsyncStatus status, const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + + auto mapped_data = staging_buffer.GetConstMappedRange(); + memcpy(dst, mapped_data, size); +} + +void BufferManager::RefreshPendingBuffers() { + pending_staging_buffers_.clear(); + storage_cache_->OnRefresh(); + uniform_cache_->OnRefresh(); + query_resolve_cache_->OnRefresh(); + default_cache_->OnRefresh(); +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const { + if (usage 
& WGPUBufferUsage_Storage) {
+    return *storage_cache_;
+  } else if (usage & WGPUBufferUsage_Uniform) {
+    return *uniform_cache_;
+  } else if (usage & WGPUBufferUsage_QueryResolve) {
+    return *query_resolve_cache_;
+  } else {
+    return *default_cache_;
+  }
+}
+
+IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const {
+  return GetCacheManager(wgpuBufferGetUsage(buffer));
+}
+
+std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) {
+  return std::make_unique(context, storage_buffer_cache_mode, uniform_buffer_cache_mode, query_resolve_buffer_cache_mode);
+}
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h
new file mode 100644
index 0000000000000..00febfbc29f1b
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/buffer_manager.h
@@ -0,0 +1,95 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+
+#ifdef __EMSCRIPTEN__
+#include
+#endif
+
+#include
+
+#include "core/framework/execution_provider.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+class WebGpuContext;
+
+enum class BufferCacheMode {
+  Disabled,
+  LazyRelease,
+  Simple,
+  Bucket
+};
+std::ostream& operator<<(std::ostream& os, BufferCacheMode mode);
+
+//
+// IBufferCacheManager is an interface for buffer cache management.
+//
+// By implementing this interface, we can have different buffer cache management strategies.
+// Currently, we have 4 strategies:
+// - Disabled: no cache. always allocate a new buffer and release it immediately after use.
+// - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh.
+// - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache.
+// - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket.
+//
+class IBufferCacheManager {
+ public:
+  virtual ~IBufferCacheManager() = default;
+
+  // calculate actual buffer size to allocate based on the requested size.
+  virtual size_t CalculateBufferSize(size_t request_size) = 0;
+
+  // return a buffer if available in cache. otherwise empty.
+  virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) = 0;
+
+  // register a newly created buffer
+  virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size) = 0;
+
+  // release a buffer
+  virtual void ReleaseBuffer(WGPUBuffer buffer) = 0;
+
+  // when a stream refresh is requested
+  virtual void OnRefresh() = 0;
+};
+
+//
+// BufferManager manages operations on buffers.
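To make the sizing rules above concrete, here is a small standalone sketch (the bucket keys are a shortened, illustrative subset of BUCKET_DEFAULT_LIMIT_TABLE, not the full table): a request is matched to the smallest bucket that fits, and anything past the largest bucket falls back to NormalizeBufferSize, which rounds up to a multiple of 16 bytes.

#include <algorithm>
#include <cstdio>
#include <vector>

// Same rounding rule as NormalizeBufferSize in buffer_manager.cc.
static size_t NormalizeBufferSize(size_t size) {
  return (size + 15) / 16 * 16;
}

int main() {
  // Shortened subset of the default bucket keys, kept sorted as the real code requires.
  std::vector<size_t> bucket_keys = {64, 128, 256, 512, 2048, 4096};
  const size_t requests[] = {100, 512, 5000};
  for (size_t request : requests) {
    auto it = std::lower_bound(bucket_keys.begin(), bucket_keys.end(), request);
    size_t allocated = (it == bucket_keys.end()) ? NormalizeBufferSize(request) : *it;
    std::printf("request=%zu -> allocated=%zu\n", request, allocated);
  }
  // Prints: 100 -> 128, 512 -> 512, 5000 -> 5008 (past the largest bucket, rounded up to 16).
  return 0;
}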
+// +class BufferManager { + public: + BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + void Upload(void* src, WGPUBuffer dst, size_t size); + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size); + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + void Release(WGPUBuffer buffer); + void Download(WGPUBuffer src, void* dst, size_t size); + void RefreshPendingBuffers(); + + private: + IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const; + IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; + + WebGpuContext& context_; + std::unique_ptr storage_cache_; + std::unique_ptr uniform_cache_; + std::unique_ptr query_resolve_cache_; + std::unique_ptr default_cache_; + + std::vector pending_staging_buffers_; +}; + +class BufferManagerFactory { + public: + static std::unique_ptr Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + private: + BufferManagerFactory() {} +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc new file mode 100644 index 0000000000000..ce4f3e49611e2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { +ComputeContext::ComputeContext(OpKernelContext& kernel_context) + : webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())}, + kernel_context_{kernel_context} { +} + +void ComputeContext::PushErrorScope() { + if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) { + webgpu_context_.Device().PushErrorScope(wgpu::ErrorFilter::Validation); + } +} + +Status ComputeContext::PopErrorScope() { + Status status{}; + + if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_ERROR(webgpu_context_.Wait( + webgpu_context_.Device().PopErrorScope( + wgpu::CallbackMode::WaitAnyOnly, [](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) { + ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped."); + if (error_type == wgpu::ErrorType::NoError) { + return; + } + *status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message); + }, + &status))); + } + return status; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h new file mode 100644 index 0000000000000..b7ea8a58e232b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include + +#include "core/framework/execution_provider.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class Tensor; + +namespace webgpu { + +class WebGpuContext; + +class ComputeContext { + public: + ComputeContext(OpKernelContext& kernel_context); + + virtual ~ComputeContext() = default; + + // + // Get various information from the context. + // + + inline const wgpu::AdapterInfo& AdapterInfo() const { + return webgpu_context_.AdapterInfo(); + } + inline const wgpu::Limits& DeviceLimits() const { + return webgpu_context_.DeviceLimits(); + } + + // + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; + } + + // + // Get the logger. + // + inline const logging::Logger& Logger() const { + return kernel_context_.Logger(); + } + + // + // Get input tensor. + // + template + inline const T* Input(int index) const { + return kernel_context_.Input(index); + } + + // + // Get input count. + // + inline int InputCount() const { + return kernel_context_.InputCount(); + } + + // + // Set output tensor. + // + template + inline Tensor* Output(int index, TensorShapeType&& shape) { + return kernel_context_.Output(index, std::forward(shape)); + } + + // + // Get output count. + // + inline int OutputCount() const { + return kernel_context_.OutputCount(); + } + + // + // Create CPU tensor. + // + template + Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceCPUAllocator(&allocator)); + return {data_type, std::forward(shape), allocator}; + } + + // + // Create GPU tensor. + // + template + Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); + return {data_type, std::forward(shape), allocator}; + } + + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + // + // Push error scope. + // + // This is useful only when "skip_validation" is not set. + // + void PushErrorScope(); + + // + // Pop error scope. + // + // This is useful only when "skip_validation" is not set. + // + Status PopErrorScope(); + + protected: + WebGpuContext& webgpu_context_; + OpKernelContext& kernel_context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc new file mode 100644 index 0000000000000..615ae11175782 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
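The ComputeContext above is the per-kernel facade over WebGpuContext. As a minimal sketch of how the error-scope helpers are meant to bracket a dispatch (the wrapper name and call site below are assumptions for illustration; the real call sites are not shown in this diff):

// Hypothetical wrapper, for illustration only.
Status RunWithValidation(ComputeContext& context, const ProgramBase& program) {
  context.PushErrorScope();                       // no-op unless validation mode is at least Basic
  Status run_status = context.RunProgram(program);
  Status scope_status = context.PopErrorScope();  // reports WebGPU validation errors as a Status
  ORT_RETURN_IF_ERROR(run_status);
  return scope_status;
}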
+ +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + if (bytes > 0) { + void const* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + context_.BufferManager().MemCpy(static_cast(const_cast(src_data)), + static_cast(dst_data), bytes); + } else { + // copy from CPU to GPU + context_.BufferManager().Upload(const_cast(src_data), static_cast(dst_data), bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); + } + } + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h new file mode 100644 index 0000000000000..f9949576aa60b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class DataTransfer : public IDataTransfer { + public: + DataTransfer(const WebGpuContext& context) : context_{context} {}; + ~DataTransfer() {}; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + + private: + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc new file mode 100644 index 0000000000000..ee7c67ec24185 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
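Routing summary for the DataTransfer above: CanCopy() accepts CPU-to-GPU, GPU-to-GPU and GPU-to-CPU copies (CPU-to-CPU is not handled here), and CopyTensor() dispatches accordingly: CPU-to-GPU goes through BufferManager().Upload(), GPU-to-GPU through MemCpy(), and GPU-to-CPU through Download().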
+ +#include "core/providers/webgpu/generator/range.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +template +Status Range::ComputeInternal(ComputeContext& context) const { + T start = context.Input(0)->Data()[0]; + T limit = context.Input(1)->Data()[0]; + T delta = context.Input(2)->Data()[0]; + + GSL_SUPPRESS(io.2) // Ignore warning about potential overflow in (limit - start) + int64_t n = static_cast(ceil((1.0 * (limit - start)) / delta)); + if (n <= 0) { + n = 0; + } + auto* output_tensor = context.Output(0, TensorShape{n}); + if (n == 0) { + return Status::OK(); + } + + uint32_t output_size = gsl::narrow(n); + RangeProgram program{}; + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + output_size, + *reinterpret_cast(&start), + *reinterpret_cast(&delta), + }); + + return context.RunProgram(program); +} + +Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let value = bitcast(uniforms.start) + output_value_t(global_idx) * bitcast(uniforms.delta);\n" + << output.SetByOffset("global_idx", "value"); + + return Status(); +} + +#define WEBGPU_RANGE_KERNEL(TYPE) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Range, \ + kOnnxDomain, \ + 11, \ + TYPE, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 0) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Range); + +WEBGPU_RANGE_KERNEL(float) +WEBGPU_RANGE_KERNEL(int32_t) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.h b/onnxruntime/core/providers/webgpu/generator/range.h new file mode 100644 index 0000000000000..2f5812bb460ad --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +template +class Range : public WebGpuKernel { + public: + explicit Range(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +class RangeProgram : public Program { + public: + RangeProgram() : Program{"Range"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"start", ProgramUniformVariableDataType::Uint32}, + {"delta", ProgramUniformVariableDataType::Uint32}); +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc new file mode 100644 index 0000000000000..7f7a5707afa0a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -0,0 +1,310 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/common.h" +#include "core/providers/webgpu/math/binary_elementwise_ops.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + // check whether can use element-wise mode. + // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used. + // In element-wise mode, no indices calculation is needed. + if (is_lhs_scalar_ || is_rhs_scalar_ || !is_broadcast_) { + // get A data + if (is_lhs_scalar_) { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("global_idx") << ";\n"; + } + + // get B data + if (is_rhs_scalar_) { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("global_idx") << ";\n"; + } + } else { + const auto& c_indices = shader.AddIndices("bcast_indices"); + // check whether can use vectorize mode. + // If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode + // can be enabled. + // In vectorize mode, the source data of A and B will be loaded only once to calculate 4 output values. + // Use indices helpers to calculate the offset of A and B. + if (vectorize_) { + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + + shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + // get A data + if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("offset_a") << ");\n"; + } + + // get B data + if (b.NumComponents() == 4) { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("offset_b / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("offset_b") << ");\n"; + } + } else { + // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. 
+ shader.MainFunctionBody() << "var outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a0 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b0 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n" + << "let offset_a1 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b1 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n" + << "let offset_a2 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b2 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n" + << "let offset_a3 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b3 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + + // get A data + shader.MainFunctionBody() << "let a = vec4(" + << a.GetByOffset("offset_a0") << ", " + << a.GetByOffset("offset_a1") << ", " + << a.GetByOffset("offset_a2") << ", " + << a.GetByOffset("offset_a3") << ");\n"; + // get B data + shader.MainFunctionBody() << "let b = vec4(" + << b.GetByOffset("offset_b0") << ", " + << b.GetByOffset("offset_b1") << ", " + << b.GetByOffset("offset_b2") << ", " + << b.GetByOffset("offset_b3") << ");\n"; + } + } + + shader.MainFunctionBody() << c.SetByOffset("global_idx", expression_); + return Status::OK(); +} + +Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { + auto lhs_tensor = context.Input(0); + auto rhs_tensor = context.Input(1); + const auto& lhs_shape = lhs_tensor->Shape(); + const auto& rhs_shape = rhs_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); + auto output_tensor = context.Output(0, output_shape); + int64_t size = output_shape.Size(); + if (size == 0) { + return Status::OK(); + } + + bool is_broadcast = lhs_shape != rhs_shape; + bool is_lhs_scalar = lhs_shape.IsScalar(); + bool is_rhs_scalar = rhs_shape.IsScalar(); + + bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast; + bool a_last_dim_divisible_by_4 = false; + bool b_last_dim_divisible_by_4 = false; + bool shared_dimension_divisible_by_4 = false; + size_t num_shared_dimension = 0; + if (!vectorize) { + // check whether vectorize can be enabled + a_last_dim_divisible_by_4 = lhs_shape.NumDimensions() > 0 && lhs_shape[lhs_shape.NumDimensions() - 1] % 4 == 0; + b_last_dim_divisible_by_4 = rhs_shape.NumDimensions() > 0 && rhs_shape[rhs_shape.NumDimensions() - 1] % 4 == 0; + if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) { + vectorize = true; + } else { + size_t shared_dimension = 1; + for (size_t i = 1; i < output_shape.NumDimensions(); i++) { + size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1; + size_t dimB = rhs_shape.NumDimensions() >= i ? 
rhs_shape[rhs_shape.NumDimensions() - i] : 1; + if (dimA == dimB) { + shared_dimension *= dimA; + num_shared_dimension++; + } else { + break; + } + } + if (shared_dimension % 4 == 0) { + shared_dimension_divisible_by_4 = true; + vectorize = true; + } + } + } + + uint32_t vec_size = gsl::narrow((size + 3) / 4); + BinaryElementwiseProgram program{kernel_name_, + expression_, + is_broadcast, + is_lhs_scalar, + is_rhs_scalar, + vectorize}; + program + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}); + + if (is_lhs_scalar || is_rhs_scalar || !is_broadcast) { + // Mode Element-wise + // cache hint: "E{is_a_scalar}{is_b_scalar}" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Type, {is_lhs_scalar ? 1 : vec_size}, 4}, + {rhs_tensor, ProgramTensorMetadataDependency::Type, {is_rhs_scalar ? 1 : vec_size}, 4}}) + .CacheHint("E" + std::to_string(is_lhs_scalar) + std::to_string(is_rhs_scalar)); + } else if (vectorize) { + // reshape the dims to merge the shared dimension if available + bool need_reshape = shared_dimension_divisible_by_4 && num_shared_dimension > 1; + TensorShape reshaped_lhs_shape = need_reshape ? lhs_shape.Slice(0, lhs_shape.NumDimensions() - num_shared_dimension + 1) + : lhs_shape; + TensorShape reshaped_rhs_shape = need_reshape ? rhs_shape.Slice(0, rhs_shape.NumDimensions() - num_shared_dimension + 1) + : rhs_shape; + TensorShape reshaped_output_shape = need_reshape ? output_shape.Slice(0, output_shape.NumDimensions() - num_shared_dimension + 1) + : output_shape; + if (need_reshape) { + reshaped_lhs_shape[reshaped_lhs_shape.NumDimensions() - 1] = lhs_shape.SizeFromDimension(lhs_shape.NumDimensions() - num_shared_dimension); + reshaped_rhs_shape[reshaped_rhs_shape.NumDimensions() - 1] = rhs_shape.SizeFromDimension(rhs_shape.NumDimensions() - num_shared_dimension); + reshaped_output_shape[reshaped_output_shape.NumDimensions() - 1] = output_shape.SizeFromDimension(output_shape.NumDimensions() - num_shared_dimension); + } + + if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4) { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type, {(lhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type}); + } + if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4) { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type, {(rhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type}); + } + // Mode Vectorize broadcast + // cache hint: "V{a_rank};{b_rank};{output_rank}" + program + .AddIndices(reshaped_output_shape) + .AddIndices(reshaped_lhs_shape) + .AddIndices(reshaped_rhs_shape) + .CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(), + reshaped_rhs_shape.NumDimensions(), + reshaped_output_shape.NumDimensions()}, + ";")); + } else { + // Mode Broadcast + // cache hint: "B" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddIndices(output_tensor->Shape()) + .CacheHint("B"); + } + + return context.RunProgram(program); +} + +#define WEBGPU_BINARY_IMPL(OP_TYPE, ...) 
\ + class OP_TYPE final : public BinaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : BinaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_BINARY_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_KERNEL_2(OP_TYPE, VERSION, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +WEBGPU_BINARY_IMPL(Add, "a + b") +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 7, 12, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 13, 13, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Add, 14, Add, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Div, "a / b") +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 7, 12, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 13, 13, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Div, 14, Div, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Mul, "a * b") +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 7, 12, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 13, 13, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Mul, 14, Mul, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Sub, "a - b") +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(vec4(a), vec4(b)))") +WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL_2(Pow, 15, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Equal, "vec4(a == b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 7, 10, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 11, 12, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 13, 18, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Equal, 19, Equal, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Greater, "vec4(a > b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 7, 8, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 9, 12, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Greater, 13, Greater, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Less, "vec4(a < b)") 
+WEBGPU_BINARY_VERSIONED_KERNEL(Less, 7, 8, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 9, 12, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Less, 13, Less, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(GreaterOrEqual, "vec4(a >= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(GreaterOrEqual, 16, GreaterOrEqual, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(LessOrEqual, "vec4(a <= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(LessOrEqual, 16, LessOrEqual, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h new file mode 100644 index 0000000000000..84cbcdf3244d8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class BinaryElementwiseProgram final : public Program { + public: + BinaryElementwiseProgram(const std::string& kernel_name, + const std::string& expression, + const bool is_broadcast, + const bool is_lhs_scalar, + const bool is_rhs_scalar, + const bool vectorize) : Program{kernel_name}, + expression_{expression}, + is_broadcast_{is_broadcast}, + is_lhs_scalar_{is_lhs_scalar}, + is_rhs_scalar_{is_rhs_scalar}, + vectorize_{vectorize} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + std::string expression_; + bool is_broadcast_; + bool is_lhs_scalar_; + bool is_rhs_scalar_; + bool vectorize_; +}; + +class BinaryElementwise : public WebGpuKernel { + public: + BinaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression} {} + + protected: + Status ComputeInternal(ComputeContext& context) const final; + + private: + std::string kernel_name_; + std::string expression_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc new file mode 100644 index 0000000000000..8dcf63671092b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
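Before the unary kernels, a simplified sketch of the vectorized-broadcast eligibility test used in BinaryElementwise::ComputeInternal above. This is editorial only (not part of the patch) and assumes the full dimension vectors are at hand; the real code additionally tracks how many trailing dimensions are shared so it can merge them before building the shader indices.

#include <algorithm>
#include <cstdint>
#include <vector>

// The vec4 broadcast path is taken when either operand's innermost dimension is
// a multiple of 4, or when the identical trailing ("shared") dimensions of the
// two shapes multiply to a multiple of 4.
bool CanUseVec4Broadcast(const std::vector<int64_t>& lhs_dims,
                         const std::vector<int64_t>& rhs_dims) {
  if ((!lhs_dims.empty() && lhs_dims.back() % 4 == 0) ||
      (!rhs_dims.empty() && rhs_dims.back() % 4 == 0)) {
    return true;
  }
  int64_t shared_dimension = 1;
  for (size_t i = 1; i <= std::min(lhs_dims.size(), rhs_dims.size()); ++i) {
    const int64_t a = lhs_dims[lhs_dims.size() - i];
    const int64_t b = rhs_dims[rhs_dims.size() - i];
    if (a != b) {
      break;  // stop at the first dimension that actually broadcasts
    }
    shared_dimension *= a;
  }
  return shared_dimension % 4 == 0;
}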
+ +#include +#include + +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << additional_impl_; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression_); + + return Status::OK(); +} + +Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + uint32_t vec_size = gsl::narrow((size + 3) / 4); + UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }); + if (!cache_hint.empty()) { + program.CacheHint(cache_hint); + } + ORT_RETURN_IF_ERROR(ConfigureProgram(context, program)); + return context.RunProgram(program); +} + +#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public UnaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : UnaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION_FROM, VERSION_TO, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + OP_TYPE_AND_CLASS_NAME); + +// +// math +// + +WEBGPU_ELEMENTWISE_IMPL(Abs, "abs(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Neg, "-a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Neg, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Neg, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Floor, "floor(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Floor, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Ceil, "ceil(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Ceil, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Reciprocal, "1.0/a") 
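A quick note on the dispatch arithmetic in UnaryElementwise::ComputeInternal above (the binary kernels use the same formula): elements are processed as vec4, so the dispatch size is a double ceil-division. A minimal sketch, assuming WORKGROUP_SIZE is 64 as defined in program.h:

#include <cstdint>

constexpr uint32_t kWorkgroupSize = 64;  // mirrors WORKGROUP_SIZE in program.h

// e.g. 1000 elements -> vec_size = 250 -> (250 + 63) / 64 = 4 workgroups
uint32_t NumWorkgroups(int64_t element_count) {
  const uint32_t vec_size = static_cast<uint32_t>((element_count + 3) / 4);
  return (vec_size + kWorkgroupSize - 1) / kWorkgroupSize;
}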
+WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Reciprocal, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sqrt, "sqrt(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sqrt, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Log, "log(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Log, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) + +class HardSigmoid final : public UnaryElementwise { + public: + HardSigmoid(const OpKernelInfo& info) + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} { + // attr[0] is alpha, attr[1] is beta + info.GetAttrOrDefault("alpha", attr, 0.2f); + info.GetAttrOrDefault("beta", attr + 1, 0.5f); + } + + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.AddUniformVariables({gsl::make_span(attr, 2)}); + return Status::OK(); + } + + protected: + float attr[2]; +}; + +WEBGPU_ELEMENTWISE_KERNEL(HardSigmoid, 6, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sin, "sin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cos, "cos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tan, "tan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Tan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asin, "asin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acos, "acos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Atan, "atan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Atan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sinh, "sinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asinh, "asinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acosh, "acosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) + +#if __APPLE__ +// Metal returns 0 for values >= 1.0. 
+// Need custom impl to return +inf for 1.0 (by dividing 1 by 0), and NaN for > 1.0 (by dividing 0 by 0) +WEBGPU_ELEMENTWISE_IMPL(Atanh, + "select(" + " select(x_value_t(1.0), x_value_t(0.0), a > x_value_t(1.0)) / x_value_t(0.0)," + " atanh(a)," + " a < x_value_t(1.0))", + "", + ShaderUsage::UseValueTypeAlias) +#else +WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") +#endif +WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Not, "!a") +WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(Not, 1) + +// No longer support Clip < opset 11 (where min and max are attributes) +// +// Use template class for "Clip" because the implementation is significantly different between float16 and float32 +template +class Clip final : public UnaryElementwise { + public: + Clip(const OpKernelInfo& info) + : UnaryElementwise{info, + "Clip", + std::is_same_v ? ClipF16Impl : ClipImpl, + "", ShaderUsage::UseElementTypeAlias} {} + + Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { + const auto* clip_min_tensor = context.Input(1); + const auto* clip_max_tensor = context.Input(2); + + const T attr[] = {clip_min_tensor ? clip_min_tensor->Data()[0] + : std::numeric_limits::lowest(), + clip_max_tensor ? clip_max_tensor->Data()[0] + : std::numeric_limits::max()}; + if constexpr (std::is_same_v) { + // F16: stores span as a single float + float encoded_value = *reinterpret_cast(attr); + program.AddUniformVariable({encoded_value}); + } else { + static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32"); + // stores span as-is + program.AddUniformVariable({gsl::make_span(attr, 2)}); + } + return Status::OK(); + } + + // uniforms.attr is a f32 value. It is encoded as a float for 2 f16 values. + // bitcast>(uniforms.attr)[0] is clip_min, bitcast>(uniforms.attr)[1] is clip_max + constexpr static const char ClipF16Impl[] = "clamp(a, vec4(bitcast>(uniforms.attr)[0]), vec4(bitcast>(uniforms.attr)[1]))"; + + // the size of element of uniforms.attr should be the same as x_element_t. 
use bitcast to convert between them + // uniforms.attr[0] is clip_min, uniforms.attr[1] is clip_max + constexpr static const char ClipImpl[] = "clamp(a, vec4(bitcast(uniforms.attr[0])), vec4(bitcast(uniforms.attr[1])))"; +}; +#define WEBGPU_CLIP_KERNEL(TYPE) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(Clip, kOnnxDomain, 13, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip); +WEBGPU_CLIP_KERNEL(float) +WEBGPU_CLIP_KERNEL(MLFloat16) + +// +// activation +// + +class LinearUnit : public UnaryElementwise { + public: + LinearUnit(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl, + float default_alpha) + : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} { + info.GetAttrOrDefault("alpha", &alpha_, default_alpha); + } + + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.AddUniformVariables({alpha_}); + return Status::OK(); + } + + protected: + float alpha_; +}; + +#define WEBGPU_LU_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public LinearUnit { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) +WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) + +class Gelu : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) + : UnaryElementwise{info, + "Gelu", + info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? 
TanhImpl : ErfImpl, + ShaderUsage::UseValueTypeAlias} { + cache_hint = info.GetAttrOrDefault("approximate", "none"); + } +}; + +WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.attr) * a, a, a >= vec4(0))", "", 0.01f) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.attr))", "", 1.0f) +WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h new file mode 100644 index 0000000000000..70fa81d21f95d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class UnaryElementwiseProgram final : public Program { + public: + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage) + : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size + {"attr", ProgramUniformVariableDataType::Float32}); // float type attribute(s) + // TODO: add u32/i32 attribute(s) if needed + + private: + std::string_view expression_; + std::string_view additional_impl_; + ShaderUsage additional_usage_; +}; + +// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch +// the std::string to std::string_view. This will avoid the cost of copying the string. 
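A note on the Clip kernel in unary_elementwise_ops.cc above: for f16, the two half-precision bounds are bit-packed into a single f32 uniform and decoded in WGSL with a bitcast to vec2 of f16. A host-side sketch of that encoding (editorial, not part of the patch; assumes a little-endian host, which the reinterpret_cast in the kernel relies on as well):

#include <cstdint>
#include <cstring>

// lo_bits holds the raw IEEE-754 half bits of clip_min, hi_bits those of clip_max.
float PackTwoHalvesIntoFloat(uint16_t lo_bits, uint16_t hi_bits) {
  const uint32_t packed = static_cast<uint32_t>(lo_bits) |
                          (static_cast<uint32_t>(hi_bits) << 16);
  float encoded;
  std::memcpy(&encoded, &packed, sizeof(encoded));  // same bit pattern the WGSL bitcast reads back
  return encoded;
}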
+ +class UnaryElementwise : public WebGpuKernel { + public: + UnaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl = "", + ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl}, + additional_usage_{usage} {} + + protected: + std::string cache_hint; + + Status ComputeInternal(ComputeContext& context) const final; + virtual Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const { + program.AddUniformVariables({{}}); // empty for attribute(s) + return Status::OK(); + } + + private: + std::string kernel_name_; + std::string expression_; + std::string additional_impl_; + ShaderUsage additional_usage_; +}; + +constexpr const char ErfImpl[] = R"( +const r0 = 0.3275911; +const r1 = 0.254829592; +const r2 = -0.284496736; +const r3 = 1.421413741; +const r4 = -1.453152027; +const r5 = 1.061405429; + +fn erf_v(v: x_value_t) -> x_value_t { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +} +)"; + +constexpr const char HardSigmoidImpl[] = R"( +fn hard_sigmoid_v(v: vec4) -> vec4 { + let alpha = x_element_t(uniforms.attr[0]); + let beta_v = vec4(uniforms.attr[1]); + return max(vec4(0.0), + min(vec4(1.0), alpha * v + beta_v)); +} +)"; + +// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) +// https://github.com/gpuweb/gpuweb/issues/4458 +constexpr const char TanhImpl[] = R"( +fn tanh_v(a: x_value_t) -> x_value_t { + let expr = exp(-2 * abs(a)); + return sign(a) * (1 - expr) / (1 + expr); +} +)"; + +constexpr const char EluImpl[] = R"( +fn elu(a: x_element_t) -> x_element_t { + let alpha = x_element_t(uniforms.attr); + return select((exp(a) - 1.0) * alpha, a, a >= 0.0); +} + +fn elu_v(v: vec4) -> vec4 { + return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); +} +)"; + +// default GELU expression, depending on ErfImpl +constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; + +// fast GELU expression, depending on TanhImpl +constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))"; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc new file mode 100644 index 0000000000000..1ee771e945820 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -0,0 +1,155 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" + +namespace onnxruntime { +namespace webgpu { + +static int GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { + int64_t rank = static_cast(tensor_rank); + if (axis < -rank && axis >= rank) { + ORT_THROW("invalid axis: ", axis); + } + return gsl::narrow(axis < 0 ? 
axis + rank : axis); +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scale", ShaderUsage::UseUniform); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + int components = x.NumComponents(); + std::string bias = (has_bias_) ? " + bias[j]" : ""; + std::string simpl1 = (simplified_) ? "" : " - mean * mean"; + std::string simpl2 = (simplified_) ? "" : " - mean"; + + shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") + << "let offset = global_idx * uniforms.norm_size_vectorized;\n" + << "var mean_vector = f32_val_t(0);\n" + << "var mean_square_vector = f32_val_t(0);\n" + << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" + << " let value = f32_val_t(x[h + offset]);\n" + << " mean_vector += value;\n" + << " mean_square_vector += value * value;\n" + << "}\n" + << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n" + << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" + << " let f32input = f32_val_t(x[j + offset]);\n" + << " let f32scale = f32_val_t(scale[j]);\n" + << " output[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" + << "}\n"; + + return Status::OK(); +} + +template +Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* scale = context.Input(1); + const auto* bias = context.Input(2); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + + const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); + const uint32_t norm_count = gsl::narrow(x_shape.SizeToDimension(axis)); + const int64_t norm_size = x_shape.SizeFromDimension(axis); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = gsl::narrow((norm_size + components - 1) / components); + + const auto scale_size = scale->Shape().Size(); + const auto bias_size = (bias) ? bias->Shape().Size() : 0; + if (scale_size != norm_size || (bias && bias_size != norm_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", norm_size, + ". Size of scale and bias (if provided) must match this. 
Got scale size of ", + scale_size, " and bias size of ", bias_size); + } + + LayerNormProgram program{bias != nullptr, is_fp16, simplified}; + + program + .CacheHint(simplified) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(norm_count)}, + }) + .AddUniformVariables({ + {static_cast(norm_size)}, + }) + .AddUniformVariables({ + {static_cast(norm_size_vectorized)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h new file mode 100644 index 0000000000000..17a9edbf4dd01 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class LayerNormProgram final : public Program { + public: + LayerNormProgram(bool has_bias, + bool is_fp16, + bool simplified) : Program{"LayerNorm"}, + has_bias_{has_bias}, + is_fp16_{is_fp16}, + simplified_{simplified} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"norm_count", ProgramUniformVariableDataType::Uint32}, + {"norm_size", ProgramUniformVariableDataType::Uint32}, + {"norm_size_vectorized", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool has_bias_; + bool is_fp16_; + bool simplified_; +}; + +template +class LayerNorm final : public WebGpuKernel { + public: + LayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("axis", &axis_, -1); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + info.GetAttrOrDefault("stash_type", &stash_type_, 1); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + int64_t axis_; + float epsilon_; + int64_t stash_type_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc new file mode 100644 index 0000000000000..d1d4c242c4697 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramUniformVariableValue::ProgramUniformVariableValue() + : length{0}, data_type{} {} // representing an empty uniform variable + +ProgramUniformVariableValue::ProgramUniformVariableValue(float value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, &value, sizeof(float)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(uint32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, &value, sizeof(uint32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(int32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, &value, sizeof(int32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(MLFloat16 value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, &value, sizeof(MLFloat16)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, values.data(), sizeof(float), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, values.data(), sizeof(uint32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, values.data(), sizeof(int32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, values.data(), sizeof(MLFloat16), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, + const void* ptr, + size_t element_byte_size, + size_t length /* = 1 */) + : length{length}, data_type{data_type} { + ORT_ENFORCE(length > 0, "number of element of uniform variable must be greater than 0"); + + data.resize(length * element_byte_size); + memcpy(data.data(), ptr, length * element_byte_size); +} + +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) { + os << ProgramUniformVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) { + os << ProgramConstantDataTypeName[std::underlying_type::type(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) { + bool first = true; + if ((dep & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { + os << "Type"; + first = false; + } + if ((dep & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + if (!first) os << "|"; + os << "Rank"; + first = false; + } + if ((dep & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + if (!first) os << "|"; + os << "Shape"; + first = false; + } + if (first) { + os << "None"; + } + + return os; +} + +#ifndef NDEBUG +constexpr std::string_view ProgramVariableDataTypeName[] = { + "f32", // Float32 + "f32x2", // Float32x2 + "f32x4", // Float32x4 + "f16", // Float16 + "f16x2", // Float16x2 + "f16x4", // Float16x4 + "i32", // Int32 + "i32x2", // Int32x2 + "i32x4", // Int32x4 + "u32", // Uint32 + "u32x2", // Uint32x2 + "u32x4", // Uint32x4 + "i64", // Int64 + "u64", // Uint64 + 
"boolx4", // Boolx4 + "u8x4", // Uint8x4 + "u8x8", // Uint8x8 + "u8x16", // Uint8x16 +}; +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { + os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} +#endif + +int NumberOfComponents(ProgramVariableDataType type) { + switch (type) { + case ProgramVariableDataType::Float32: + case ProgramVariableDataType::Int32: + case ProgramVariableDataType::Uint32: + case ProgramVariableDataType::Int64: + case ProgramVariableDataType::Uint64: + case ProgramVariableDataType::Float16: + return 1; + case ProgramVariableDataType::Float32x2: + case ProgramVariableDataType::Int32x2: + case ProgramVariableDataType::Uint32x2: + case ProgramVariableDataType::Float16x2: + return 2; + case ProgramVariableDataType::Float32x4: + case ProgramVariableDataType::Int32x4: + case ProgramVariableDataType::Uint32x4: + case ProgramVariableDataType::Float16x4: + case ProgramVariableDataType::Boolx4: + case ProgramVariableDataType::Uint8x4: + return 4; + case ProgramVariableDataType::Uint8x8: + return 8; + case ProgramVariableDataType::Uint8x16: + return 16; + default: + return -1; + } +} + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { + if (component == 1) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return ProgramVariableDataType::Int64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return ProgramVariableDataType::Uint64; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 2) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32x2; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 4) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return ProgramVariableDataType::Boolx4; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 8) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x8; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 16) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x16; + default: + return ProgramVariableDataType::InvalidType; + } + } else { + return ProgramVariableDataType::InvalidType; + } +} + +namespace { +TensorShape GetReducedShape(const 
TensorShape& shape, int component /* > 1 */) { + ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0, + "Cannot reduce shape ", shape.ToString(), " by component=", component); + TensorShape reduced_shape = shape; + reduced_shape[reduced_shape.NumDimensions() - 1] /= component; + return reduced_shape; +} +} // namespace + +ProgramInput::ProgramInput(const Tensor* tensor) : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + +ProgramOutput::ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + +ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) + : name_{name}, + metadata_{metadata}, + dispatch_group_size_x_{0}, + dispatch_group_size_y_{0}, + dispatch_group_size_z_{0}, + workgroup_size_x_{0}, + workgroup_size_y_{0}, + workgroup_size_z_{0} { +} + +ProgramBase& ProgramBase::AddInput(ProgramInput&& input) { + inputs_.emplace_back(input); + return *this; +} + +ProgramBase& ProgramBase::AddInputs(std::initializer_list inputs) { + inputs_.insert(inputs_.end(), inputs.begin(), inputs.end()); + return *this; +} + +ProgramBase& ProgramBase::AddOutput(ProgramOutput&& output) { + outputs_.emplace_back(output); + return *this; +} + +ProgramBase& ProgramBase::AddOutputs(std::initializer_list outputs) { + outputs_.insert(outputs_.end(), outputs.begin(), outputs.end()); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { + return SetDispatchGroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y) { + return SetDispatchGroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { + dispatch_group_size_x_ = x; + dispatch_group_size_y_ = y; + dispatch_group_size_z_ = z; + return *this; +} + +ProgramBase& 
ProgramBase::SetWorkgroupSize(uint32_t x) { + return SetWorkgroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y) { + return SetWorkgroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { + workgroup_size_x_ = x; + workgroup_size_y_ = y; + workgroup_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::AddUniformVariable(ProgramUniformVariableValue&& variable) { + variables_.emplace_back(variable); + return *this; +} + +ProgramBase& ProgramBase::AddUniformVariables(std::initializer_list variables) { + variables_.insert(variables_.end(), variables.begin(), variables.end()); + return *this; +} + +ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list overridable_constants) { + overridable_constants_.insert(overridable_constants_.end(), overridable_constants.begin(), overridable_constants.end()); + return *this; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h new file mode 100644 index 0000000000000..1562ec158b40a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.h @@ -0,0 +1,605 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include + +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webgpu { +class ShaderHelper; +class ComputeContext; +class WebGpuContext; + +// data type of uniform variable +enum class ProgramUniformVariableDataType { + Float32, + Float16, + Uint32, + Int32, +}; +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType); + +constexpr size_t ProgramUniformVariableDataTypeSize[] = {sizeof(float), sizeof(uint16_t), sizeof(uint32_t), sizeof(int32_t)}; + +constexpr std::string_view ProgramUniformVariableDataTypeName[] = {"f32", "f16", "u32", "i32"}; + +// represents a runtime value of a uniform variable +struct ProgramUniformVariableValue { + ProgramUniformVariableValue(); // representing an empty uniform variable + ProgramUniformVariableValue(float value); + ProgramUniformVariableValue(uint32_t value); + ProgramUniformVariableValue(int32_t value); + ProgramUniformVariableValue(MLFloat16 value); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + + size_t length; + ProgramUniformVariableDataType data_type; + std::vector data; + + private: + ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, const void* ptr, size_t element_byte_size, size_t length = 1); +}; + +// represents a uniform variable definition +struct ProgramUniformVariableDefinition { + constexpr ProgramUniformVariableDefinition(std::string_view name, ProgramUniformVariableDataType data_type) + : name{name}, data_type{data_type} {} + + std::string_view name; + ProgramUniformVariableDataType data_type; +}; + +// data type of constant +enum class ProgramConstantDataType { + Float32, + Float16, + Uint32, + Int32, + Bool +}; +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType); + +constexpr std::string_view ProgramConstantDataTypeName[] = {"f32", "f16", "u32", "i32", "bool"}; + +// represents a constant in a program +struct ProgramConstant { + constexpr ProgramConstant(std::string_view name, float 
value) : name{name}, type{ProgramConstantDataType::Float32}, f32{value} {} + constexpr ProgramConstant(std::string_view name, uint32_t value) : name{name}, type{ProgramConstantDataType::Uint32}, u32{value} {} + constexpr ProgramConstant(std::string_view name, int32_t value) : name{name}, type{ProgramConstantDataType::Int32}, i32{value} {} + constexpr ProgramConstant(std::string_view name, MLFloat16 value) : name{name}, type{ProgramConstantDataType::Float16}, f16{value} {} + constexpr ProgramConstant(std::string_view name, bool value) : name{name}, type{ProgramConstantDataType::Bool}, boolean{value} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; +}; + +// represents a runtime value of an overridable constant +struct ProgramOverridableConstantValue { + constexpr ProgramOverridableConstantValue() : type{}, u32{}, has_value{false} {} // representing not overriding + constexpr ProgramOverridableConstantValue(float value) : type{ProgramConstantDataType::Float32}, f32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(uint32_t value) : type{ProgramConstantDataType::Uint32}, u32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(int32_t value) : type{ProgramConstantDataType::Int32}, i32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(MLFloat16 value) : type{ProgramConstantDataType::Float16}, f16{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(bool value) : type{ProgramConstantDataType::Bool}, boolean{value}, has_value{true} {} + + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_value; +}; + +// represents an overridable constant definition. may or may not have a default value. 
+struct ProgramOverridableConstantDefinition { + constexpr ProgramOverridableConstantDefinition(std::string_view name, ProgramConstantDataType type) + : name{name}, type{type}, u32{}, has_default_value{false} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, float value) + : name{name}, type{ProgramConstantDataType::Float32}, f32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, uint32_t value) + : name{name}, type{ProgramConstantDataType::Uint32}, u32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, int32_t value) + : name{name}, type{ProgramConstantDataType::Int32}, i32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, MLFloat16 value) + : name{name}, type{ProgramConstantDataType::Float16}, f16{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, bool value) + : name{name}, type{ProgramConstantDataType::Bool}, boolean{value}, has_default_value{true} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_default_value; +}; + +// represents whether the program shader depends on the type, rank, or shape of an input/output tensor +enum class ProgramTensorMetadataDependency : int { + None = 0, + Type = 1, + Rank = 2, + Shape = 4, + TypeAndRank = Type | Rank, + TypeAndShape = Type | Shape, +}; +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency); + +inline ProgramTensorMetadataDependency operator|(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a | (int&)b); +} +inline ProgramTensorMetadataDependency operator&(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a & (int&)b); +} +inline ProgramTensorMetadataDependency& operator|=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a |= (int&)b); +} +inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b); +} + +constexpr SafeInt WORKGROUP_SIZE = 64; + +// data type of variable +// +// this is not a full list of all possible data types in shader programs. +// it only includes what are used in WebGPU EP. 
+enum class ProgramVariableDataType { + InvalidType = -1, + Float32, + Float32x2, + Float32x4, + Float16, + Float16x2, + Float16x4, + Int32, + Int32x2, + Int32x4, + Uint32, + Uint32x2, + Uint32x4, + Int64, + Uint64, + Boolx4, + Uint8x4, + Uint8x8, + Uint8x16 +}; +#ifndef NDEBUG +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); +#endif + +int NumberOfComponents(ProgramVariableDataType type); + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); + +struct ProgramInput { + ProgramInput(const Tensor* tensor); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); + + const Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +struct ProgramOutput { + ProgramOutput(Tensor* tensor); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); + + Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +enum class ValidationMode { + Disabled = 0, + WGPUOnly, + Basic, + Full +}; + +namespace details { +class ProgramWrapper; +} + +struct ProgramMetadata { + gsl::span constants; + gsl::span overridable_constants; + gsl::span uniform_variables; +}; + +class ProgramBase { + public: + // + // chain-style methods for setting properties + // + + // set the cache hint for the program + template + ProgramBase& CacheHint(T&&... hints) { + cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(hints)...), "|"); + return *this; + } + + // add a program input + ProgramBase& AddInput(ProgramInput&& input); + // add multiple program inputs + ProgramBase& AddInputs(std::initializer_list inputs); + // add a program output + ProgramBase& AddOutput(ProgramOutput&& output); + // add multiple program outputs + ProgramBase& AddOutputs(std::initializer_list outputs); + // add a program variable for indices + ProgramBase& AddIndices(const TensorShape& shape); + // add a program variable for indices + ProgramBase& AddIndices(TensorShape&& shape); + + // set the size of dispatch groups. Y and Z are 1 if not specified. + ProgramBase& SetDispatchGroupSize(uint32_t x); + // set the size of dispatch groups. Z is 1 if not specified. + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y); + // set the size of dispatch groups. + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); + + // set the size of a workgroup grid. Y and Z are 1 if not specified. + ProgramBase& SetWorkgroupSize(uint32_t x); + // set the size of a workgroup grid. Z is 1 if not specified. + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y); + // set the size of a workgroup grid. + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z); + + // add a uniform variable. + // + // the specified uniform variable should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& AddUniformVariable(ProgramUniformVariableValue&& variable); + // add multiple uniform variables. 
+ // + // the specified uniform variables should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& AddUniformVariables(std::initializer_list variables); + + // set the overridable constants + // + // the specified overridable constants should match the overridable constant definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS. + ProgramBase& SetOverridableConstants(std::initializer_list overridable_constants); + + // + // shader code generation + // + + virtual Status GenerateShaderCode(ShaderHelper& shader) const = 0; + + // + // Properties Getters + // + + inline const std::string& Name() const { return name_; } + inline const ProgramMetadata& Metadata() const { return metadata_; } + inline const std::string& CacheHint() const { return cache_hint_; } + inline const std::vector& Inputs() const { return inputs_; } + inline const std::vector& Outputs() const { return outputs_; } + inline const std::vector& Indices() const { return indices_; } + inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } + inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } + inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } + inline uint32_t WorkgroupSizeX() const { return workgroup_size_x_; } + inline uint32_t WorkgroupSizeY() const { return workgroup_size_y_; } + inline uint32_t WorkgroupSizeZ() const { return workgroup_size_z_; } + inline const std::vector& UniformVariables() const { return variables_; } + inline const std::vector& OverridableConstants() const { return overridable_constants_; } + + protected: + virtual ~ProgramBase() = default; + + private: + // Make the constructor private to prevent direct instantiation or inheritance from this class + // Use the Program template class as base class to create a new program class + explicit ProgramBase(std::string_view name, ProgramMetadata&& metadata); + + std::string name_; + ProgramMetadata metadata_; + + std::string cache_hint_; + std::vector inputs_; + std::vector outputs_; + std::vector indices_; + + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + uint32_t workgroup_size_x_; + uint32_t workgroup_size_y_; + uint32_t workgroup_size_z_; + + std::vector variables_; + std::vector overridable_constants_; + + friend class details::ProgramWrapper; +}; + +namespace details { +// class ProgramWrapper is for accessing private constructor of ProgramBase. +// only ProgramWrapper can access the constructor of ProgramBase because ProgramWrapper is the only friend class of +// ProgramBase. This design is used to prevent direct instantiation or inheritance from ProgramBase. +class ProgramWrapper : public ProgramBase { + protected: + template + ProgramWrapper(Args&&... 
args) : ProgramBase{std::forward(args)...} {} +}; + +#if defined(ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK) +#error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined" +#endif + +#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ + private: \ + template \ + static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ + template \ + static auto test_has_##identifier(...)->std::false_type; \ + \ + template ::value && /* - is a const std::array */ \ + std::is_const_v && /* - has "const" modifier */ \ + !std::is_member_pointer_v>> /* - is static */ \ + static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ + template \ + static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ + \ + public: \ + static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ + static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type(0))::value + +// the following template class checks whether the type is a const std::array +template +struct is_const_std_array : std::false_type {}; +template +struct is_const_std_array> : std::true_type {}; + +// the following template class checks whether certain static members exist in the derived class (SFINAE) +template +class DerivedProgramClassTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(constants, ProgramConstant); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(overridable_constants, ProgramOverridableConstantDefinition); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(uniform_variables, ProgramUniformVariableDefinition); +}; + +// compile-time tests for the type check +// +// TODO: move this to test folder +namespace test { + +template +class TestTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +}; + +struct TestClass_Empty {}; +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_0 { + int b; +}; +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_1 { + int a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_2 { + const int a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_0 { + const int a[2]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_1 { + static constexpr int a[] = {0}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_2 { + static int a[]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_3 { + static const int a[]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_0 { + std::array a = {1}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_1 { + static constexpr std::array a = {1, 2}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_2 { + static const std::array a; +}; 
+static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_3 { + static constexpr const std::array a = {1, 2, 3, 4}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_4 { + static std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +} // namespace test + +#undef ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK + +} // namespace details + +template +class Program : public details::ProgramWrapper { + public: + template + Program(Args&&... args) : details::ProgramWrapper{std::forward(args)..., GetMetadata()} {} + + static ProgramMetadata GetMetadata() { + ProgramMetadata metadata; + if constexpr (details::DerivedProgramClassTypeCheck::has_constants) { + constexpr const ProgramConstant* ptr = T::constants.data(); + constexpr size_t len = T::constants.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_constants_with_correct_type, + "Derived class of \"Program\" has member \"constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_CONSTANTS() to declare constants."); + + metadata.constants = {ptr, len}; + } else { + metadata.constants = {}; + } + + if constexpr (details::DerivedProgramClassTypeCheck::has_overridable_constants) { + constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data(); + constexpr size_t len = T::overridable_constants.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type, + "Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS() to declare overridable constants."); + + metadata.overridable_constants = {ptr, len}; + } else { + metadata.overridable_constants = {}; + } + + if constexpr (details::DerivedProgramClassTypeCheck::has_uniform_variables) { + constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data(); + constexpr size_t len = T::uniform_variables.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type, + "Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() or WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES() to declare uniform variables."); + + metadata.uniform_variables = {ptr, len}; + } else { + metadata.uniform_variables = {}; + } + + return metadata; + } +}; + +namespace details { +// helper function to convert a C-style array to std::array +// +// This is basically the same as std::to_array in C++20. 
+// +template +constexpr auto _to_std_array_impl(T (&arr)[N], std::index_sequence) -> std::array, N> { + return {{arr[Idx]...}}; +} + +template +constexpr auto _to_std_array(T (&arr)[N]) -> std::array, N> { + return _to_std_array_impl(arr, std::make_index_sequence{}); +} + +// helper function to concatenate a std::array and a C-style array to a std::array +// +template +constexpr std::array, L + R> _concat2_impl(const std::array& lhs, + T (&rhs)[R], + std::index_sequence, + std::index_sequence) { + return {{lhs[IdxL]..., rhs[IdxR]...}}; +} + +template +constexpr std::array, L + R> _concat2(const std::array& lhs, T (&rhs)[R]) { + return _concat2_impl(lhs, rhs, std::make_index_sequence{}, std::make_index_sequence{}); +} + +} // namespace details +#define WEBGPU_PROGRAM_DEFINE_(identifier, T, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::details::_to_std_array(identifier##_own) + +#define WEBGPU_PROGRAM_EXTEND_(identifier, T, BASE, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::details::_concat2(BASE::identifier, identifier##_own) + +#define WEBGPU_PROGRAM_DEFINE_CONSTANTS(...) \ + WEBGPU_PROGRAM_DEFINE_(constants, onnxruntime::webgpu::ProgramConstant, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(constants, onnxruntime::webgpu::ProgramConstant, BASE, __VA_ARGS__) + +#define WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS(...) \ + WEBGPU_PROGRAM_DEFINE_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, BASE, __VA_ARGS__) + +#define WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(...) \ + WEBGPU_PROGRAM_DEFINE_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, BASE, __VA_ARGS__) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc new file mode 100644 index 0000000000000..a5c21563dbfcd --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
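For context on how the pieces above fit together, the sketch below shows how a kernel-side program class is expected to consume the Program<T> wrapper and the WEBGPU_PROGRAM_DEFINE_* macros. The class name, the "tile_size" constant, and the "vec_size" uniform are hypothetical illustrations, not part of this change; only Program<T>, the macro names, and the metadata plumbing come from the code above, and the exact constructor forms of ProgramConstant and ProgramUniformVariableDefinition are assumed from how they are used elsewhere in this file.

// Illustrative sketch only (class, constant, and uniform names are hypothetical).
class ExampleElementwiseProgram final : public Program<ExampleElementwiseProgram> {
 public:
  ExampleElementwiseProgram() : Program{"ExampleElementwise"} {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // expected to be emitted into WGSL roughly as: const tile_size: u32 = 16;
  WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", 16u});

  // expected to become a field of the generated "struct Uniforms"
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});
};
// Program<ExampleElementwiseProgram>::GetMetadata() then detects these members through the
// DerivedProgramClassTypeCheck traits above and packages them into a ProgramMetadata instance.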
+ +#include "core/providers/webgpu/program_cache_key.h" + +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +// macro "D" - append to the ostream only in debug build +#ifndef NDEBUG // if debug build +#define D(str) << str +#else +#define D(str) +#endif + +namespace { +// append the info of an input or output to the cachekey +void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, + bool& first) { + if (first) { + first = false; + } else { + ss << '|'; + } + + if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { +#ifndef NDEBUG // if debug build + ss << var_type; +#else + ss << static_cast(var_type); +#endif + ss << ';'; + } + + if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + ss D("Dims=") << tensor.Shape().ToString(); + } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + ss D("Rank=") << tensor.Shape().NumDimensions(); + } +} +} // namespace + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { + SS(ss, kStringInitialSizeCacheKey); + + // final key format: + // =[]:::: + // + // = ||... + // = ,, + // = + // = ||... + // = + // = ||... + // = ; + ss << program.Name(); + + // append custom cache hint if any + if (auto& hint = program.CacheHint(); !hint.empty()) { + ss << '[' D("CacheHint=") << hint << ']'; + } + + // append workgroup size if overridden + if (auto x = program.WorkgroupSizeX(), y = program.WorkgroupSizeY(), z = program.WorkgroupSizeZ(); + x != 0 || y != 0 || z != 0) { + ss << ":" D("WorkgroupSize="); + // only append non-zero values. zero values are considered as use default + if (x > 0) { + ss << x; + } + ss << ","; + if (y > 0) { + ss << y; + } + ss << ","; + if (z > 0) { + ss << z; + } + } + + ss << ":" D("DispatchDim=") << (is_1d_dispatch ? "1" : "3"); + ss << ":" D("UniformSizes="); + bool first = true; + for (const auto& uniform : program.UniformVariables()) { + if (first) { + first = false; + } else { + ss << "|"; + } + if (uniform.length > 0) { + ss << uniform.length; + } + } + + ss << ":" D("Inputs="); + first = true; + for (const auto& input : program.Inputs()) { + AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first); + } + + ss << ":" D("Outputs="); + first = true; + for (const auto& output : program.Outputs()) { + AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); + } + + return SS_GET(ss); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/shape_op.h b/onnxruntime/core/providers/webgpu/program_cache_key.h similarity index 51% rename from onnxruntime/core/codegen/mti/tensor/shape_op.h rename to onnxruntime/core/providers/webgpu/program_cache_key.h index 67ee2de50eca9..22ba19ebd0f25 100644 --- a/onnxruntime/core/codegen/mti/tensor/shape_op.h +++ b/onnxruntime/core/providers/webgpu/program_cache_key.h @@ -2,13 +2,15 @@ // Licensed under the MIT License. 
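To make the cache-key format described in the comment above concrete: a hypothetical program named "Transpose" with cache hint "1,0", default workgroup size, a 1-D dispatch, a single uniform of length 2, one input depending on type and rank, and one output depending on type and shape would produce a debug-build key roughly of the following shape (all values are illustrative, and the exact spelling of the variable type comes from an operator<< that is not part of this hunk):

Transpose[CacheHint=1,0]:DispatchDim=1:UniformSizes=2:Inputs=f32;Rank=2:Outputs=f32;Dims={4,8}

In a release build the D(...) labels are compiled out, so the same program yields a shorter key with the same field order and separators.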
#pragma once + #include -#include + +#include "core/providers/webgpu/program.h" namespace onnxruntime { -namespace tvm_codegen { +namespace webgpu { -tvm::Tensor Shape(const tvm::Tensor& X, const std::string& name = "shape"); +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch); -} // namespace tvm_codegen +} // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc new file mode 100644 index 0000000000000..109bac34d6503 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/common.h" + +#include "core/common/common.h" +#include "core/common/logging/logging.h" + +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks) + : name{program.Name()}, + compute_pipeline{compute_pipeline}, + shape_uniform_ranks{shape_uniform_ranks} {} + +Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { + ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); + + auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; + if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { + auto size = static_cast(x) * static_cast(y) * static_cast(z); + uint32_t dispatch_avg = gsl::narrow(std::ceil(std::sqrt(size))); + if (dispatch_avg > limit_per_dimension) { + dispatch_avg = gsl::narrow(std::ceil(std::cbrt(size))); + ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); + x = y = z = dispatch_avg; + } else { + x = y = dispatch_avg; + z = 1; + } + } + return Status::OK(); +} + +Status ProgramManager::Build(const ProgramBase& program, + const ProgramMetadata& program_metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniform_ranks) const { + ShaderHelper shader_helper{program, + program_metadata, + device_, + limits_, + normalized_dispatch_x, + normalized_dispatch_y, + normalized_dispatch_z}; + ORT_RETURN_IF_ERROR(shader_helper.Init()); + + ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices()); + + // code is a large std::string that contains the final shader code + std::string code; + ORT_RETURN_IF_ERROR(shader_helper.GenerateSourceCode(code, shape_uniform_ranks)); + + LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] Start ===\n\n" + << code + << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] End ===\n"; + + wgpu::ShaderModuleWGSLDescriptor wgsl_descriptor{}; + wgsl_descriptor.code = code.c_str(); + + 
wgpu::ShaderModuleDescriptor descriptor{}; + descriptor.nextInChain = &wgsl_descriptor; + + auto shader_module = device_.CreateShaderModule(&descriptor); + + // TODO: a new cache hierarchy for constants. + // + // Explanation: + // Currently, we use Uniforms for dynamic data. This helps to reduce the number of program artifacts. + // + // "dynamic data" here means the data that is determined at runtime, such as the shape of the input tensor. + // + // However, some programs may not necessarily depend on dynamic data. For example, "Clip" may depend on the value of "min" and "max". + // We are using uniforms for the value of "min" and "max" in the current implementation, but usually "min" and "max" are determined + // earlier because they are either from Attributes or from the initializers of the model. + // + // Questions: + // - can we use one instance of ShaderModule to create multiple ComputePipeline? + // - is there any benefit to do so compared to the current implementation? + // + + // process overridable constants if available + size_t constant_count = program.OverridableConstants().size(); + + // making a copy of the constant names is required because they are stored as std::string_view in the program + // metadata. A value of std::string_view is not guaranteed to be a C-style string (null-terminated) and hence + // cannot be used directly in the WebGPU API (which expects a const char*). + std::vector<std::string> constant_names; + constant_names.reserve(constant_count); + std::vector<wgpu::ConstantEntry> constant_entries; + constant_entries.reserve(constant_count); + for (size_t i = 0; i < constant_count; ++i) { + const auto& constant_override = program.OverridableConstants()[i]; + const auto& constant_def = program_metadata.overridable_constants[i]; + + if (constant_override.has_value) { + double value = 0; + switch (constant_override.type) { + case ProgramConstantDataType::Bool: + value = constant_override.boolean ? 1 : 0; + break; + case ProgramConstantDataType::Float16: + // convert f16(MLFloat16) -> f32(float) -> f64(double) + // because the value of a constant must be a double in WebGPU API, it is expensive to use f16 overridable constants.
+ value = constant_override.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + value = constant_override.f32; + break; + case ProgramConstantDataType::Int32: + value = constant_override.i32; + break; + case ProgramConstantDataType::Uint32: + value = constant_override.u32; + break; + } + + const auto& name_string = constant_names.emplace_back(constant_def.name); + wgpu::ConstantEntry entry{}; + entry.key = name_string.c_str(); + entry.value = value; + constant_entries.push_back(std::move(entry)); + } + } + + wgpu::ProgrammableStageDescriptor compute_stage{}; + compute_stage.module = shader_module; + compute_stage.entryPoint = "main"; + if (!constant_entries.empty()) { + compute_stage.constants = constant_entries.data(); + compute_stage.constantCount = constant_entries.size(); + } + + wgpu::ComputePipelineDescriptor pipeline_descriptor{}; + pipeline_descriptor.compute = compute_stage; +#ifndef NDEBUG // if debug build + pipeline_descriptor.label = program.Name().c_str(); +#endif + + compute_pipeline = device_.CreateComputePipeline(&pipeline_descriptor); + + return Status(); +} + +const ProgramArtifact* ProgramManager::Get(const std::string& key) const { + auto result = programs_.find(key); + if (result != programs_.end()) { + return &result->second; + } + + return nullptr; +} + +const ProgramArtifact* ProgramManager::Set(const std::string& key, ProgramArtifact&& program) { + return &(programs_.emplace(key, std::move(program)).first->second); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h new file mode 100644 index 0000000000000..eded1cfa17970 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
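The Build/Get/Set trio above is intended to be driven by the cache key from program_cache_key.cc. The sketch below shows the expected lookup-or-build flow; the surrounding context object that owns program_mgr and supplies the normalized dispatch sizes (x, y, z) is not part of this hunk and is assumed, and the element type of shape_uniform_ranks is taken to be int because it is elided here.

// Illustrative sketch of the intended caching flow (not a verbatim excerpt of this change).
const std::string key = CalculateProgramCacheKey(program, is_1d_dispatch);
const ProgramArtifact* artifact = program_mgr.Get(key);
if (artifact == nullptr) {
  wgpu::ComputePipeline pipeline;
  std::vector<int> shape_uniform_ranks;  // element type assumed; it is elided in this hunk
  ORT_RETURN_IF_ERROR(program_mgr.Build(program, metadata,
#ifndef NDEBUG
                                        key,
#endif
                                        x, y, z, pipeline, shape_uniform_ranks));
  artifact = program_mgr.Set(key, ProgramArtifact{program, std::move(pipeline), std::move(shape_uniform_ranks)});
}
// artifact->compute_pipeline is then bound to a compute pass and dispatched with (x, y, z).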
+ +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { + +class ProgramArtifact { + public: + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks); + + const std::string name; + const wgpu::ComputePipeline compute_pipeline; + const std::vector shape_uniform_ranks; + + ProgramArtifact(ProgramArtifact&&) = default; + ProgramArtifact& operator=(ProgramArtifact&&) = delete; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact); +}; + +class ProgramManager { + public: + ProgramManager(const wgpu::Device& device, const wgpu::Limits& limits) : device_(device), limits_(limits) {} + + Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const; + + Status Build(const ProgramBase& program, + const ProgramMetadata& metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniform_ranks) const; + const ProgramArtifact* Get(const std::string& key) const; + const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); + + private: + std::unordered_map programs_; + const wgpu::Device& device_; + const wgpu::Limits& limits_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc new file mode 100644 index 0000000000000..5685494556248 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -0,0 +1,530 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
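A short worked example for the NormalizeDispatchGroupSize declared above, assuming the device reports maxComputeWorkgroupsPerDimension = 65535 (the real limit comes from the wgpu::Limits captured by the manager):

// Requested dispatch: x = 1000000, y = 1, z = 1  (exceeds the per-dimension limit)
// size = 1000000; ceil(sqrt(1000000)) = 1000 <= 65535, so the 2-D fallback applies:
//   x = y = 1000, z = 1   -> 1000 * 1000 * 1 = 1000000 workgroups, each dimension within the limit
// Only if ceil(sqrt(size)) itself exceeded the limit would the cube-root branch spread the work as x = y = z.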
+ +#include +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/string_utils.h" +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +ShaderHelper::ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z) + : device_{device}, + limits_{limits}, + dispatch_group_size_x_{dispatch_group_size_x}, + dispatch_group_size_y_{dispatch_group_size_y}, + dispatch_group_size_z_{dispatch_group_size_z}, + program_{program}, + program_metadata_{program_metadata}, + additional_implementation_ss_{&additional_implementation_}, + body_ss_{&body_} {} + +Status ShaderHelper::Init() { + // dispatch group size is normalized so no need to validate it here + + // validate workgroup size + auto workgroup_size_x = program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ(); + + ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && + workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && + workgroup_size_z <= limits_.maxComputeWorkgroupSizeZ, + "Workgroup size exceeds the maximum allowed size [", + limits_.maxComputeWorkgroupSizeX, ", ", + limits_.maxComputeWorkgroupSizeY, ", ", + limits_.maxComputeWorkgroupSizeZ, "]"); + + ORT_RETURN_IF_NOT(workgroup_size_x * workgroup_size_y * workgroup_size_z <= limits_.maxComputeInvocationsPerWorkgroup, + "Workgroup size exceeds the maximum allowed invocations ", limits_.maxComputeInvocationsPerWorkgroup); + + // init body string stream + bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; + body_.reserve(4096); + additional_implementation_.reserve(1024); + + // append header for main function so it is ready for user to append main function body + body_ss_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" + "fn main(@builtin(global_invocation_id) global_id : vec3,\n" + " @builtin(workgroup_id) workgroup_id : vec3,\n" + " @builtin(local_invocation_index) local_idx : u32,\n" + " @builtin(local_invocation_id) local_id : vec3"; + if (!is_1d_dispatch) { + body_ss_ << ",\n" + " @builtin(num_workgroups) num_workgroups : vec3"; + } + body_ss_ << ") {\n"; + if (is_1d_dispatch) { + body_ss_ << " let global_idx = global_id.x;\n" + " let workgroup_idx = workgroup_id.x;\n"; + } else { + body_ss_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + } + + return Status::OK(); +} + +const ShaderVariableHelper& ShaderHelper::AddInput(const std::string& name, ShaderUsage usage) { + const size_t input_index = input_vars_.size(); + ORT_ENFORCE(input_index < program_.Inputs().size(), + "Too many inputs in the program (", program_.Inputs().size(), ")"); + + const auto& dims = program_.Inputs()[input_index].use_override_shape ? 
program_.Inputs()[input_index].override_shape + : program_.Inputs()[input_index].tensor->Shape(); + return AddVariableImpl(true, name, usage, dims); +} + +const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, ShaderUsage usage) { + const size_t output_index = output_vars_.size(); + ORT_ENFORCE(output_index < program_.Outputs().size(), + "Too many outputs in the program (", program_.Outputs().size(), ")"); + + const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape + : program_.Outputs()[output_index].tensor->Shape(); + return AddVariableImpl(false, name, usage, dims); +} + +const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, bool use_uniform) { + const size_t indices_index = indices_vars_.size(); + return *indices_vars_.emplace_back( + std::make_unique(name, + ProgramVariableDataType::InvalidType, + use_uniform ? ShaderUsage::UseUniform : ShaderUsage::None, + program_.Indices()[indices_index])); +} + +#ifndef NDEBUG // if debug build +namespace { +// Validate if the tensor element type matches the program variable data type +Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || + var_type == ProgramVariableDataType::Float32x2 || + var_type == ProgramVariableDataType::Float32x4, + "Unexpected program variable type ", int(var_type), " for float32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 || + var_type == ProgramVariableDataType::Float16x2 || + var_type == ProgramVariableDataType::Float16x4, + "Unexpected program variable type ", int(var_type), " for float16 tensor"); + + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || + var_type == ProgramVariableDataType::Int32x2 || + var_type == ProgramVariableDataType::Int32x4, + "Unexpected program variable type ", int(var_type), " for int32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 || + var_type == ProgramVariableDataType::Uint32x2 || + var_type == ProgramVariableDataType::Uint32x4, + "Unexpected program variable type ", int(var_type), " for uint32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int64, + "Unexpected program variable type ", int(var_type), " for int64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint64, + "Unexpected program variable type ", int(var_type), " for uint64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Boolx4, + "Unexpected program variable type ", int(var_type), " for bool tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint8x4 || + var_type == ProgramVariableDataType::Uint8x8 || + var_type == ProgramVariableDataType::Uint8x16, + "Unexpected program variable type ", int(var_type), " for uint8 tensor"); + break; + default: + ORT_RETURN_IF(true, "Unsupported data type: ", element_type); + // todo: add int4/uint4 + } + return Status::OK(); +} + +// Validate if the number of components and override shape match the 
original shape +Status ValidateVariableShape(const TensorShape& origin_shape, + bool use_override_shape, + const TensorShape& override_shape, + int num_components) { + if (use_override_shape) { + // if override shape specified, assert override_size == ceil( origin_size / 4 ) + ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(), + "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); + } + + return Status::OK(); +} + +// Validate if the dependency and variable usage match +Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderUsage usage, bool is_input) { + bool dependency_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool dependency_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + bool dependency_type = (dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type; + + // if dependency is already set for shape, it is no need to set for rank. + ORT_RETURN_IF(dependency_rank && dependency_shape, + "Dependency cannot set for both \"Rank\" and \"Shape\"."); + + // if dependency is set for shape, it's already part of the shader cache. no need to use uniform. + ORT_RETURN_IF(dependency_shape && (usage & ShaderUsage::UseUniform), + "Dependency is set for \"Shape\", using uniform for shape is not allowed."); + + // for input variable, check is more strict. + // this is because usually output shape is determined by the existing information, which is already part of the shader cache. + if (is_input) { + // if dependency is not set for type, should not use type alias for element and value. + // storage type is always used. so setting not depending on type is at user's own risk. + ORT_RETURN_IF(!dependency_type && (usage & (ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias)), + "Input dependency is not set for \"Type\", but type alias for element type or value type is used."); + + // if dependency is not set for rank and shape, the shader should not use shape and stride. + ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderUsage::UseShapeAndStride), + "Input dependency is set for neither \"Rank\" nor \"Shape\", but variable shape and stride is used."); + } + + return Status::OK(); +} +} // namespace + +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), + input.use_override_shape, + input.use_override_shape ? input.override_shape : input.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); + + return Status::OK(); +} +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), + output.use_override_shape, + output.use_override_shape ? 
output.override_shape : output.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); + + return Status::OK(); +} + +#endif // NDEBUG + +const ShaderVariableHelper& ShaderHelper::AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims) { + ORT_ENFORCE(input_vars_.size() + output_vars_.size() < limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); + + ProgramVariableDataType type = ProgramVariableDataType::InvalidType; + auto& vars = is_input ? input_vars_ : output_vars_; + + if (is_input) { + const auto& input = program_.Inputs()[vars.size()]; + type = input.var_type; + } else { + const auto& output = program_.Outputs()[vars.size()]; + type = output.var_type; + } + + const auto& var = vars.emplace_back(std::make_unique(name, type, usage, dims)); + return *var; +} + +Status ShaderHelper::ValidateShapeForInputs() const { + // Validate input as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(input_vars_.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", input_vars_.size(), ", Program: ", program_.Inputs().size()); + for (size_t i = 0; i < input_vars_.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate input shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars_[i])); +#endif + + // check input dependencies with actual usages. + auto usage = input_vars_[i]->usage_; + auto dependency = program_.Inputs()[i].dependency; + bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + ORT_RETURN_IF_NOT((use_rank || input_vars_[i]->rank_ < 2) && !use_shape, + "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); + // If you want neither hard-coded shape nor shape uniform, use a flattened shape (rank=1). + // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + } + } + } + return Status::OK(); +} + +Status ShaderHelper::ValidateShapeForOutputs() const { + // Validate output as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(output_vars_.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", output_vars_.size(), ", Program: ", program_.Outputs().size()); + + for (size_t i = 0; i < output_vars_.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate output shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars_[i])); +#endif + + // check output dependencies with actual usages. + auto usage = output_vars_[i]->usage_; + auto dependency = program_.Outputs()[i].dependency; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not + // necessarily a part of the cache key. 
+ ORT_RETURN_IF_NOT(!use_shape, + "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + } + } + } + return Status::OK(); +} + +Status ShaderHelper::ValidateIndices() const { + ORT_RETURN_IF_NOT(indices_vars_.size() == program_.Indices().size(), + "Mismatched indices variable count. Shader: ", indices_vars_.size(), ", Program: ", program_.Indices().size()); + + return Status::OK(); +} + +Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { + SS(ss, kStringInitialSizeShaderSourceCode); + + // + // Section feature enabling + // + if (std::any_of(program_.Inputs().begin(), + program_.Inputs().end(), + [](const ProgramInput& input) { + return input.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + }) || + std::any_of(program_.Outputs().begin(), + program_.Outputs().end(), + [](const ProgramOutput& output) { + return output.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + })) { + ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); + ss << "enable f16;\n"; + if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { + ss << "enable subgroups_f16;\n"; + } + } + if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { + ss << "enable subgroups;\n"; + } + + // + // Section constants + // + ss << "const workgroup_size_x: u32 = " << (program_.WorkgroupSizeX() == 0 ? uint32_t(WORKGROUP_SIZE) : program_.WorkgroupSizeX()) + << ";\nconst workgroup_size_y: u32 = " << (program_.WorkgroupSizeY() == 0 ? uint32_t(1) : program_.WorkgroupSizeY()) + << ";\nconst workgroup_size_z: u32 = " << (program_.WorkgroupSizeZ() == 0 ? 
uint32_t(1) : program_.WorkgroupSizeZ()) + << ";\n"; + + for (const auto& constant : program_metadata_.constants) { + ss << "const " << constant.name << ": " << constant.type << " = "; + WriteConstantValue(ss, constant); + ss << ";\n"; + } + + size_t override_constant_count = program_metadata_.overridable_constants.size(); + for (size_t i = 0; i < override_constant_count; ++i) { + // size and type are previously checked to match + const auto& constant_def = program_metadata_.overridable_constants[i]; + const auto& constant_override = program_.OverridableConstants()[i]; + + ss << "override " << constant_def.name << ": " << constant_def.type << " = "; + if (constant_override.has_value) { + WriteConstantValue(ss, constant_override); + } else { + WriteConstantValue(ss, constant_def); + } + ss << ";\n"; + } + + // + // Input/output variables + // + size_t variable_count = 0; + for (const auto& input : input_vars_) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; + } + for (const auto& output : output_vars_) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; + } + + // + // uniform variables + // + + // store shape uniform ranks in shape_uniform_ranks + bool use_any_shape_uniform = false; + ORT_ENFORCE(shape_uniform_ranks.size() == 0); + shape_uniform_ranks.reserve(input_vars_.size() + output_vars_.size() + indices_vars_.size()); + + for (const auto& input : input_vars_) { + bool use_uniform = (input->usage_ & ShaderUsage::UseUniform) && + (input->usage_ & ShaderUsage::UseShapeAndStride) && + input->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? input->rank_ : 0); + } + for (const auto& output : output_vars_) { + bool use_uniform = (output->usage_ & ShaderUsage::UseUniform) && + (output->usage_ & ShaderUsage::UseShapeAndStride) && + output->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? output->rank_ : 0); + } + for (const auto& indices : indices_vars_) { + bool use_uniform = (indices->usage_ & ShaderUsage::UseUniform) && + (indices->usage_ & ShaderUsage::UseShapeAndStride) && + indices->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? indices->rank_ : 0); + } + + if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), + program_.UniformVariables().cend(), + [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { + bool first = true; + ss << "struct Uniforms {"; + + // lambda append_uniform is used to append one uniform variable to the uniform struct + auto append_uniform = [&ss, &first](std::string_view name, ProgramUniformVariableDataType data_type, size_t length) { + if (length == 0) { + return; + } + + if (first) { + first = false; + } else { + ss << ","; + } + + auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? 
"@align(16) " : ""; + ss << "\n " << alignment << name << ": "; + if (length > 4) { + if (data_type == ProgramUniformVariableDataType::Float16) { + size_t array_size = (length + 7) / 8; + ss << "array, " << array_size << ">"; + } else { + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; + } + } else if (length > 1) { + ss << "vec" << length << "<" << data_type << ">"; + } else { + ss << data_type; + } + }; + + for (const auto& input : input_vars_) { + const size_t rank = input->rank_; + if (rank > 0 && (input->usage_ & ShaderUsage::UseUniform) && (input->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = input->name_ + "_shape"; + std::string stride = input->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (const auto& output : output_vars_) { + const size_t rank = output->rank_; + if (rank > 0 && (output->usage_ & ShaderUsage::UseUniform) && (output->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = output->name_ + "_shape"; + std::string stride = output->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (const auto& indices : indices_vars_) { + const size_t rank = indices->rank_; + if (rank > 0 && (indices->usage_ & ShaderUsage::UseUniform) && (indices->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = indices->name_ + "_shape"; + std::string stride = indices->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (size_t i = 0; i < program_.UniformVariables().size(); i++) { + const auto& uniform_def = program_metadata_.uniform_variables[i]; + const auto& uniform_value = program_.UniformVariables()[i]; + append_uniform(uniform_def.name, uniform_def.data_type, uniform_value.length); + } + + ss << "\n};\n" + "@group(0) @binding(" + << variable_count << ") var uniforms: Uniforms;\n"; + } + + // + // Indices helper + // + ss << "\n"; + for (const auto& var : input_vars_) { + var->Impl(ss); + } + for (const auto& var : output_vars_) { + var->Impl(ss); + } + for (const auto& var : indices_vars_) { + var->Impl(ss); + } + ss << "\n"; + + // + // Additional Implementation + // + ss << additional_implementation_; + + // + // Main Function Body + // + ss << body_; + ss << "\n" + "}\n"; + + code = SS_GET(ss); + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h new file mode 100644 index 0000000000000..a4b96edc63c74 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/string_utils.h" + +namespace onnxruntime { +namespace webgpu { + +class ShaderHelper final { + // The content of a shader code is composed of the following parts: + // + // ** + // ** section: feature sets definition + // ** + // // this sections enable features like "enable f16;". need to be defined at the beginning of the shader. + // + // ** + // ** section: constants and overridable constants + // ** + // // this section defines constants and overridable constants. + // - constants are defined as "const a:f32 = 1.0;". It's hard coded in the shader. + // - overridable constants are defined as "override a:f32 = 1.0;" (may override or not) + // or "override b:u32;" (must override) + // the value can be overriden by pipeline creation config. + // + // ** + // ** section: inputs and outputs + // ** + // // this section defines input and output variables. + // user can call shader_helper.AddVariable() to add input and output variables. + // + // ** + // ** section: uniforms + // ** + // // this section defines uniform type and variables. + // + // ** + // ** section: indices helper generated utility functions + // ** + // // this section defines utility functions to calculate indices. + // + // ** + // ** section: additional implementation + // ** + // // this section contains additional implementation provided by the user. + // user can call shader_helper.AppendImplementation() to append additional implementation. + // + // ** + // ** section: main function + // ** + // // this section contains the main function of the shader. + // user can call shader_helper.MainFunctionBody() to set the main function body. + // + + public: + ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z); + + Status Init(); + + // Add an input variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. + const ShaderVariableHelper& AddInput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an output variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. + const ShaderVariableHelper& AddOutput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an indices variable to the shader. + const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform = true); + + // Get the string stream for additional implementation code to the shader. + inline OStringStream& AdditionalImplementation() { + return additional_implementation_ss_; + } + + // Get the string stream for the main function body of the shader. 
+ inline OStringStream& MainFunctionBody() { + return body_ss_; + } + + std::string GuardAgainstOutOfBoundsWorkgroupSizes(std::string_view size) const { + return MakeStringWithClassicLocale(" if (global_idx >= ", size, ") { return; }\n"); + } + + private: + template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} + void WriteConstantValue(std::ostream& ss, const ConstantType& constant) const { + switch (constant.type) { + case ProgramConstantDataType::Float16: + ss << constant.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + ss << constant.f32; + break; + case ProgramConstantDataType::Int32: + ss << constant.i32; + break; + case ProgramConstantDataType::Uint32: + ss << constant.u32; + break; + case ProgramConstantDataType::Bool: + ss << (constant.boolean ? "true" : "false"); + break; + default: + ORT_THROW("Invalid constant type", constant.type); + } + } + + const ShaderVariableHelper& AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims); + +#ifndef NDEBUG // if debug build + Status ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const; + Status ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const; +#endif + + Status ValidateShapeForInputs() const; + Status ValidateShapeForOutputs() const; + Status ValidateIndices() const; + + // Generate source code. + // + // This function: + // - performs validation if neccessary, + // - appends the ranks for variables to the shape_uniform_ranks. + // (The rank value is zero if no uniform is needed for the variable.) + // - generates the final source code. + // + // \param code The generated full WGSL source code. + // \param shape_uniform_ranks The ranks for variables that need a uniform for the shape. + // + Status GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const; + friend class ProgramManager; + + const wgpu::Device& device_; + const wgpu::Limits& limits_; + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + const ProgramBase& program_; + const ProgramMetadata& program_metadata_; + + std::vector> input_vars_; + std::vector> output_vars_; + std::vector> indices_vars_; + std::string additional_implementation_; + OStringStream additional_implementation_ss_; + std::string body_; + OStringStream body_ss_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc new file mode 100644 index 0000000000000..15020b801c97d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
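To show how the ShaderHelper declared above is meant to be used from a kernel, here is a minimal sketch of a GenerateShaderCode implementation. The program class and the "vec_size" uniform are the hypothetical ones from the earlier sketch near program.h; only the ShaderHelper and ShaderVariableHelper calls come from this change.

// Illustrative sketch only.
Status ExampleElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& input = shader.AddInput("x", ShaderUsage::UseUniform);
  const ShaderVariableHelper& output = shader.AddOutput("y", ShaderUsage::UseUniform);
  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
                            << "  " << output.SetByOffset("global_idx", input.GetByOffset("global_idx")) << "\n";
  return Status::OK();
}
// Init() then wraps this body in the generated "fn main(...)" skeleton, after the binding
// declarations, the Uniforms struct, and the per-variable indices helpers.

One related detail from GenerateSourceCode above: uniform fields longer than 4 elements are packed into arrays of 4-wide chunks (8-wide, 16-byte-aligned chunks for f16), so a u32 uniform of length 6 occupies (6 + 3) / 4 = 2 chunks and an f16 uniform of length 10 occupies (10 + 7) / 8 = 2 chunks emitted with "@align(16)".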
+ +#include +#include +#include + +#include "core/providers/webgpu/shader_variable.h" + +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +constexpr static const std::string_view STORAGE_TYPE_ARRAY[] = { + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "vec2", // Int64 + "vec2", // Uint64 + "u32", // Boolx4 + "u32", // Uint8x4 + "vec2", // Uint8x8 + "vec4", // Uint8x16 +}; +constexpr static const auto STORAGE_TYPE = details::_to_std_array(STORAGE_TYPE_ARRAY); + +constexpr static const std::string_view VALUE_TYPE_ARRAY[] = { + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "i32", // Int64 (trancated to i32) + "u32", // Uint64 (trancated to u32) + "vec4", // Boolx4 + "u32", // Uint8x4 (u32 as 4 elements of uint8) + "vec2", // Uint8x8 (vec2 as 2x4 elements of uint8) + "vec4", // Uint8x16 (vec4 as 4x4 elements of uint8) +}; +constexpr static const auto VALUE_TYPE = details::_to_std_array(VALUE_TYPE_ARRAY); + +constexpr static const std::string_view ELEMENT_TYPE_ARRAY[] = { + "f32", // Float32 + "f32", // Float32x2 + "f32", // Float32x4 + "f16", // Float16 + "f16", // Float16x2 + "f16", // Float16x4 + "i32", // Int32 + "i32", // Int32x2 + "i32", // Int32x4 + "u32", // Uint32 + "u32", // Uint32x2 + "u32", // Uint32x4 + "i32", // Int64 + "u32", // Uint64 + "bool", // Boolx4 + "u32", // Uint8x4 + "u32", // Uint8x8 + "u32", // Uint8x16 +}; +constexpr static const auto ELEMENT_TYPE = details::_to_std_array(ELEMENT_TYPE_ARRAY); + +inline std::string GetIndicesType(int rank) { + return rank < 2 ? "u32" + : (rank <= 4 ? MakeStringWithClassicLocale("vec", rank, "") + : MakeStringWithClassicLocale("array")); +} + +} // namespace + +ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) + : name_(name), + type_(type), + num_components_{NumberOfComponents(type)}, + rank_{gsl::narrow(dims.NumDimensions())}, + dims_{dims}, + usage_(usage), + indices_type_{GetIndicesType(rank_)}, + value_type_alias_{name_ + "_value_t"}, + element_type_alias_{name_ + "_element_t"}, + indices_type_alias_{name_ + "_indices_t"} {} + +ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) + : ShaderIndicesHelper{name, type, usage, dims} { + ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); + ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); +} + +void ShaderIndicesHelper::Impl(std::ostream& ss) const { + // Start generating code + + const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; + const std::string stride = (usage_ & ShaderUsage::UseUniform) ? "uniforms." 
+ name_ + "_stride" : name_ + "_stride"; + + // Types + if (usage_ & ShaderUsage::UseValueTypeAlias) { + SS_APPEND(ss, "alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); + } + if (usage_ & ShaderUsage::UseIndicesTypeAlias) { + SS_APPEND(ss, "alias ", indices_type_alias_, " = ", indices_type_, ";\n"); + } + if (usage_ & ShaderUsage::UseElementTypeAlias) { + SS_APPEND(ss, "alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); + } + + // Need shape and strides when (not use uniform) and (use shape and stride is enabled) + if (!(usage_ & ShaderUsage::UseUniform) && (usage_ & ShaderUsage::UseShapeAndStride) && rank_ > 0) { + SS_APPEND(ss, "const ", shape, " = ", IndicesType(), "("); + + bool first = true; + for (auto dim : dims_.GetDims()) { + if (!first) { + ss << ","; + } + + ss << dim; + first = false; + } + ss << ");\n"; + + if (rank_ > 1) { + SS_APPEND(ss, "const ", stride, " = ", GetIndicesType(rank_ - 1), "("); + first = true; + for (int i = 1; i < rank_; i++) { + if (!first) { + ss << ","; + } + ss << dims_.SizeFromDimension(i); + first = false; + } + ss << ");\n"; + } + } + + // Implementation of "fn o2i_{name}" + if (usage_ & ShaderUsage::UseOffsetToIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); + SS_APPEND(ss, " var indices: ", IndicesType(), ";\n"); + SS_APPEND(ss, " var current = offset;\n"); + for (int i = 0; i < rank_ - 1; i++) { + auto current_stride = GetElementAt(stride, i, rank_ - 1); + SS_APPEND(ss, " let dim", i, " = current / ", current_stride, ";\n"); + SS_APPEND(ss, " let rest", i, " = current % ", current_stride, ";\n"); + SS_APPEND(ss, " indices[", i, "] = dim", i, ";\n"); + SS_APPEND(ss, " current = rest", i, ";\n"); + } + SS_APPEND(ss, " indices[", rank_ - 1, "] = current;\n"); + SS_APPEND(ss, " return indices;\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn i2o_{name}" + if (usage_ & ShaderUsage::UseIndicesToOffset) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); + SS_APPEND(ss, " return "); + for (int i = 0; i < rank_ - 1; i++) { + SS_APPEND(ss, "indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); + } + SS_APPEND(ss, "indices[", rank_ - 1, "];\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn {res_name}_bi2o_{name}" + if (usage_ & ShaderUsage::UseBroadcastedIndicesToOffset) { + if (rank_ > 0) { + for (const auto& broadcasted_result_ptr : broadcasted_to_) { + const auto& broadcasted_result = *broadcasted_result_ptr; + SS_APPEND(ss, "fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); + if (rank_ == 1) { + SS_APPEND(ss, " return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); + } else { + SS_APPEND(ss, " return "); + for (int i = 0; i < rank_ - 1; i++) { + auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); + std::string current_stride = rank_ == 2 ? 
stride : GetElementAt(stride, i, rank_ - 1); + SS_APPEND(ss, current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); + } + SS_APPEND(ss, broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + } + SS_APPEND(ss, "}\n"); + } + } + } +} + +void ShaderVariableHelper::Impl(std::ostream& ss) const { + ShaderIndicesHelper::Impl(ss); + + // Implementation of "fn set_{name}" + if (usage_ & ShaderUsage::UseSet) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn set_", name_, "(d0: u32"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i, ": u32"); + } + SS_APPEND(ss, ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " set_", name_, "_by_indices(d0"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i); + } + SS_APPEND(ss, ", value);\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn set_{name}_by_indices" + if (usage_ & ShaderUsage::UseSetByIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn get_{name}" + if (usage_ & ShaderUsage::UseGet) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn get_", name_, "(d0: u32"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i, ": u32"); + } + SS_APPEND(ss, ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return get_", name_, "_by_indices(d0"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i); + } + SS_APPEND(ss, ");\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn get_{name}_by_indices" + if (usage_ & ShaderUsage::UseGetByIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS_APPEND(ss, "}\n"); + } + } +} + +std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const { + SS(ss, kStringInitialSizeGetByOffsetImpl); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << ElementType() << "(" << name_ << "[" << offset << "].x)"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: + ss << "vec4(bool(" + << name_ << "[" << offset << "] & 0xFFu), bool(" + << name_ << "[" << offset << "] & 0xFF00u), bool(" + << name_ << "[" << offset << "] & 0xFF0000u), bool(" + << name_ << "[" << offset << "] & 0xFF000000u))"; + break; + default: + ss << name_ << "[" << offset << "]"; + } + + return SS_GET(ss); +} + +std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::string_view value) const { + SS(ss, kStringInitialSizeSetByOffsetImpl); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), select(0u, 0xFFFFFFFFu, " << value << " < 0));"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), 0u);"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: + ss << name_ << "[" << offset << "]=dot(vec4(0x1, 0x100, 0x10000, 
0x1000000), vec4(" << value << "));"; + break; + default: + ss << name_ << "[" << offset << "]=" << value << ";"; + } + + return SS_GET(ss); +} + +std::string_view ShaderVariableHelper::StorageType() const { + return STORAGE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariableHelper::ValueType() const { + return (usage_ & ShaderUsage::UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariableHelper::ElementType() const { + return (usage_ & ShaderUsage::UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; +} + +std::string_view ShaderIndicesHelper::IndicesType() const { + return (usage_ & ShaderUsage::UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; +} +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h new file mode 100644 index 0000000000000..4c87bc9158890 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +template || std::is_same_v>> +std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { + // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. + if (var.rfind("uniforms.", 0) == 0) { + if (rank > 4) { + if constexpr (std::is_integral_v) { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + } else { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + } + } else { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + } else { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); + } + } + } + } + + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : std::string{var}; +} + +struct ShaderUsage { + enum : uint32_t { + None = 0, // no usage. this means no additional implementation code will be generated. + UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) + UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) + UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) 
+ UseShapeAndStride = 16, // use shape and stride for the variable + UseOffsetToIndices = 32, // use implementation of fn o2i_{name} + UseIndicesToOffset = 64, // use implementation of fn i2o_{name} + UseBroadcastedIndicesToOffset = 128, // use implementation of fn {broadcasted_result_name}_bi2o_{name} + UseSet = 256, // use implementation of fn set_{name} + UseSetByIndices = 512, // use implementation of fn set_{name}_by_indices + UseGet = 1024, // use implementation of fn get_{name} + UseGetByIndices = 2048, // use implementation of fn get_{name}_by_indices + UseUniform = 32768, // use uniform for shape and stride + } usage; + + ShaderUsage(decltype(usage) usage) : usage{usage} {} + ShaderUsage(uint32_t usage) : usage{usage} {} + + explicit operator bool() { + return usage != None; + } +}; + +// A helper class to make it easier to generate shader code related to indices calculation. +class ShaderIndicesHelper { + public: + ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + + ShaderIndicesHelper(ShaderIndicesHelper&&) = default; + ShaderIndicesHelper& operator=(ShaderIndicesHelper&&) = default; + + // get the number of components of the variable. + inline int NumComponents() const { return num_components_; } + + // get the rank of the indices. + inline int Rank() const; + + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. + // \param offset: a WGSL expression (u32) representing the offset. + inline std::string OffsetToIndices(std::string_view offset_expr) const; + + // create a WGSL expression (u32) for getting offset from indices. + // \param indices: a WGSL expression ({varname}_indices_t) representing the indices. + inline std::string IndicesToOffset(std::string_view indices_expr) const; + + // create a WGSL expression (u32) for getting original offset from broadcasted indices. + // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. + // \param broadcasted_result: the broadcasted result variable. + inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const; + + // create a WGSL expression ({varname}_indices_t) as an indices literal + // \param init: a list of indices values. + template + inline std::string Indices(TIndices&&... indices_args) const; + + // create a WGSL statement for setting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to set. + // \param value: the value (u32) to set. + template + inline std::string IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const; + + // create a WGSL expression (u32) for getting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to get. 
+ template + inline std::string IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const; + + protected: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper); + + void Impl(std::ostream& ss) const; + + std::string_view IndicesType() const; + + std::string name_; + ProgramVariableDataType type_; // for variable + int num_components_; // for variable + int rank_; + TensorShape dims_; + + mutable ShaderUsage usage_; + // the pointers stored here are owned by the ShaderHelper instance that also owns this ShaderIndicesHelper instance. + // these instances are kept valid during the lifetime of the ShaderHelper instance. + mutable std::set broadcasted_to_; + + // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. + std::string indices_type_; + + // the alias for the types + std::string value_type_alias_; + std::string element_type_alias_; + std::string indices_type_alias_; + + friend class ShaderHelper; +}; + +// A helper class to make it easier to generate shader code related to a variable setting/getting and its indices calculation. +class ShaderVariableHelper : public ShaderIndicesHelper { + public: + ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + + ShaderVariableHelper(ShaderVariableHelper&&) = default; + ShaderVariableHelper& operator=(ShaderVariableHelper&&) = default; + + // create a WGSL statement for setting data at the given indices. + // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). + template + inline std::string Set(TIndicesAndValue&&... args) const; + + // create a WGSL statement for setting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param value: the value ({varname}_value_t) to set. + inline std::string SetByIndices(std::string_view indices_var, std::string_view value) const; + + // create a WGSL statement for setting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + // \param value: the value ({varname}_value_t) to set. + template + inline std::string SetByOffset(TOffset&& offset, TValue&& value) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices: a list of indices values (u32). + template + inline std::string Get(TIndices&&... indices) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + inline std::string GetByIndices(std::string_view indices_var) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. 
+ template + inline std::string GetByOffset(TOffset&& offset) const; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper); + + void Impl(std::ostream& ss) const; + + std::string GetByOffsetImpl(std::string_view offset) const; + std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; + std::string_view StorageType() const; + std::string_view ValueType() const; + std::string_view ElementType() const; + + friend class ShaderHelper; +}; + +inline ShaderUsage operator|(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage | (uint32_t)b.usage; +} +inline ShaderUsage operator&(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage & (uint32_t)b.usage; +} +inline ShaderUsage& operator|=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage |= (uint32_t)b.usage; + return a; +} +inline ShaderUsage& operator&=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage &= (uint32_t)b.usage; + return a; +} + +namespace detail { +template >> +std::string pass_as_string(T&& v) { + return std::to_string(std::forward(v)); +} +template +std::string_view pass_as_string(std::string_view sv) { + return sv; +} +template +std::string pass_as_string(T&& v) { + return std::forward(v); +} +} // namespace detail + +inline int ShaderIndicesHelper::Rank() const { + // getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride. + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_; +} + +inline std::string ShaderIndicesHelper::OffsetToIndices(std::string_view offset_expr) const { + usage_ |= ShaderUsage::UseOffsetToIndices | ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{offset_expr} + : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); +} + +inline std::string ShaderIndicesHelper::IndicesToOffset(std::string_view indices_expr) const { + usage_ |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{indices_expr} + : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); +} + +inline std::string ShaderIndicesHelper::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const { + ORT_ENFORCE(broadcasted_result.num_components_ == -1 || + num_components_ == -1 || + broadcasted_result.num_components_ == num_components_, + "number of components should be the same for 2 variables to calculate"); + usage_ |= ShaderUsage::UseBroadcastedIndicesToOffset | ShaderUsage::UseShapeAndStride; + broadcasted_to_.insert(&broadcasted_result); + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); +} + +template +inline std::string ShaderIndicesHelper::Indices(TIndices&&... indices_args) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(IndicesType(), "(", + absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), + ')'); +} + +template +inline std::string ShaderIndicesHelper::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? 
MakeStringWithClassicLocale(indices_var, '=', value, ';') + : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); +} + +template +inline std::string ShaderIndicesHelper::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{indices_var} + : GetElementAt(indices_var, idx_expr, rank_); +} + +template +inline std::string ShaderVariableHelper::SetByOffset(TOffset&& offset, TValue&& value) const { + return SetByOffsetImpl(detail::pass_as_string(offset), detail::pass_as_string(value)); +} + +template +inline std::string ShaderVariableHelper::Set(TIndicesAndValue&&... args) const { + usage_ |= ShaderUsage::UseShapeAndStride; + ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); + if constexpr (sizeof...(TIndicesAndValue) == 1) { + return SetByOffset("0", std::forward(args)...); + } else if constexpr (sizeof...(TIndicesAndValue) == 2) { + return SetByOffset(std::forward(args)...); + } else { + usage_ |= ShaderUsage::UseSet | ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(args)...), ", "), + ");"); + } +} + +inline std::string ShaderVariableHelper::SetByIndices(std::string_view indices_var, std::string_view value) const { + usage_ |= ShaderUsage::UseShapeAndStride; + if (rank_ < 2) { + return SetByOffset(indices_var, value); + } else { + usage_ |= ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, "_by_indices(", indices_var, ", ", value, ");"); + } +} + +template +inline std::string ShaderVariableHelper::GetByOffset(TOffset&& offset) const { + return GetByOffsetImpl(detail::pass_as_string(offset)); +} + +template +inline std::string ShaderVariableHelper::Get(TIndices&&... indices) const { + usage_ |= ShaderUsage::UseShapeAndStride; + ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); + if constexpr (sizeof...(TIndices) == 0) { + return GetByOffset("0"); + } else if constexpr (sizeof...(TIndices) == 1) { + return GetByOffset(std::forward(indices)...); + } else { + usage_ |= ShaderUsage::UseGet | ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(indices)...), ", "), + ')'); + } +} + +inline std::string ShaderVariableHelper::GetByIndices(std::string_view indices_var) const { + usage_ |= ShaderUsage::UseShapeAndStride; + if (rank_ < 2) { + return GetByOffset(indices_var); + } else { + usage_ |= ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, "_by_indices(", indices_var, ")"); + } +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/string_macros.h b/onnxruntime/core/providers/webgpu/string_macros.h new file mode 100644 index 0000000000000..7821d9c49a171 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_macros.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
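// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// The GetElementAt() helper in shader_variable.h above rewrites indexing into
// "uniforms.*" arrays with rank > 4, because such uniforms are packed into
// vec4 slots (and f16 data into pairs of vec4<f16>). The standalone C++ below
// mirrors only the constant-index math of that helper; the function name
// UniformElementAt is illustrative and does not exist in the patch.
#include <cstdint>
#include <iostream>
#include <string>

static std::string UniformElementAt(const std::string& var, uint32_t idx, bool is_f16) {
  if (is_f16) {
    // eight f16 values occupy two vec4<f16> per 16-byte uniform slot
    return var + "[" + std::to_string(idx / 8) + "][" + std::to_string((idx % 8) / 4) +
           "][" + std::to_string((idx % 8) % 4) + "]";
  }
  // four u32/i32/f32 values occupy one vec4 per slot
  return var + "[" + std::to_string(idx / 4) + "][" + std::to_string(idx % 4) + "]";
}

int main() {
  std::cout << UniformElementAt("uniforms.a_shape", 6, /*is_f16=*/false) << "\n";  // uniforms.a_shape[1][2]
  std::cout << UniformElementAt("uniforms.a_shape", 6, /*is_f16=*/true) << "\n";   // uniforms.a_shape[0][1][2]
}
// --- end of editor's sketch ---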
+ +#pragma once + +#include "core/providers/webgpu/string_utils.h" + +// macro "SS" - declare an ostream variable and its string buffer +#define SS(ss, reserve_size) \ + std::string ss##_str; \ + ss##_str.reserve(reserve_size); \ + ::onnxruntime::webgpu::OStringStream ss(&ss##_str) + +// macro "SS_GET" - get the string from the ostream +#define SS_GET(ss) ss##_str + +// macro "SS_APPEND" - use function call style to append to the ostream +#define SS_APPEND(ss, ...) ::onnxruntime::webgpu::detail::OStringStreamAppend(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/string_utils.h b/onnxruntime/core/providers/webgpu/string_utils.h new file mode 100644 index 0000000000000..e6d7097ad6182 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_utils.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/make_string.h" +#include + +namespace onnxruntime { +namespace webgpu { + +constexpr const size_t kStringInitialSizeSetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeGetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeShaderSourceCode = 2048; +#ifndef NDEBUG +constexpr const size_t kStringInitialSizeCacheKey = 512; +#else +constexpr const size_t kStringInitialSizeCacheKey = 256; +#endif + +using OStringStream = absl::strings_internal::OStringStream; + +namespace detail { +inline void OStringStreamAppendImpl(std::ostream& /*ss*/) noexcept { +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... args) noexcept { + OStringStreamAppendImpl(ss, t); + OStringStreamAppendImpl(ss, args...); +} + +template +inline void OStringStreamAppend(std::ostream& ss, const Args&... args) { + return OStringStreamAppendImpl(ss, ::onnxruntime::detail::if_char_array_make_ptr_t(args)...); +} + +} // namespace detail + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc new file mode 100644 index 0000000000000..8b5bede34e6d0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
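// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// What the SS / SS_GET / SS_APPEND macros above amount to in spirit: SS(ss, N)
// declares a std::string reserved to N bytes plus an OStringStream writing into
// it, SS_APPEND streams each argument, and SS_GET hands back the backing string.
// The sketch below substitutes std::ostringstream purely to show the calling
// pattern; BuildGuard is an illustrative name, not a function in the patch.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

static std::string BuildGuard(uint32_t size) {
  std::ostringstream ss;  // the patch writes into a pre-reserved std::string instead
  // SS_APPEND(ss, "if (global_idx >= ", size, "u) { return; }") appends each
  // argument via operator<<, exactly like this:
  ss << "if (global_idx >= " << size << "u) { return; }";
  return ss.str();  // SS_GET(ss) returns the backing string without an extra copy
}

int main() {
  std::cout << BuildGuard(1024) << "\n";  // if (global_idx >= 1024u) { return; }
}
// --- end of editor's sketch ---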
+ +#include + +#include "core/providers/webgpu/tensor/cast.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +const std::vector& CastOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 6, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 9, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 13, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_KERNEL_EX( + Cast, + kOnnxDomain, + 19, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); + +Status Cast::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + uint32_t vec_size = gsl::narrow((size + 3) / 4); + + CastProgram program{to_}; + program + .AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .CacheHint(std::to_string(to_)); + return context.RunProgram(program); +} + +Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& input = sh.AddInput("x", ShaderUsage::UseUniform); + const auto& output = sh.AddOutput("y", ShaderUsage::UseUniform); + std::string expression; + switch (to_) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + expression = "vec4(a)"; + break; + default: + ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported."); + } + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h new file mode 100644 index 0000000000000..ef5c4d5d0dabe --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ 
-0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class CastProgram final : public Program { + public: + CastProgram(int32_t to) : Program{"Cast"}, to_{to} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int32_t to_; +}; + +class Cast final : public WebGpuKernel { + public: + Cast(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t to; + Status status = info.GetAttr("to", &to); + ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); + to_ = gsl::narrow(to); + + // ignore attribute 'saturate' as float8 is not supported in WebGPU + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int32_t to_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc new file mode 100644 index 0000000000000..c708f24dcc330 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/tensor/concat.h" + +#include "core/common/inlined_containers.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + start, \ + end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +#define WEBGPU_CONCAT_KERNEL(version) \ + ONNX_OPERATOR_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +WEBGPU_CONCAT_VERSIONED_KERNEL(1, 3) +WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) +WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) +WEBGPU_CONCAT_KERNEL(13) + +void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(index: u32) -> u32 {\n" + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" + << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" + << " return i;\n" + << " }\n" + << " }\n" + << " return " << input_count << ";\n" + << "}\n"; +} + +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { + os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i == 0) { + os << " if (input_index == 0u) {\n"; + } else if (i == inputs.size() - 1) { + os << " } else {\n"; + } else { + os << " } else if (input_index == " << i << "u) {\n"; + } + os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; + } + os << " }\n" + "}\n"; +} + +Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { + size_t input_count = Inputs().size(); + std::vector inputs; + 
inputs.reserve(input_count); + for (size_t i = 0; i < input_count; ++i) { + inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + // add implementation of fn calculate_input_index + AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); + // add implementation of fn assign_output_data + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); + const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" + << " let input_index = calculate_input_index(indices_axis);\n" + << " if (input_index != 0u) {\n" + << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" + << " }\n" + " assign_output_data(global_idx, input_index, indices);\n"; + return Status::OK(); +} + +Status Concat::ComputeInternal(ComputeContext& context) const { + int input_count = context.InputCount(); + InlinedTensorsVector input_tensors; + input_tensors.reserve(input_count); + for (int i = 0; i < input_count; ++i) { + input_tensors.push_back(context.Input(i)); + } + + Prepare prepare; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), input_tensors, prepare)); + if (prepare.output_num_elements == 0) { + return Status::OK(); + } + + uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); + + ConcatProgram program{prepare.axis}; + + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(input_count); + uint32_t sum = 0; + for (int i = 0; i < input_count; ++i) { + const auto& input = prepare.inputs[i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + + auto axis_size = input.tensor->Shape()[prepare.axis]; + sum += static_cast(axis_size); + sizes_in_concat_axis.push_back(sum); + } + + size_t non_empty_input_count = sizes_in_concat_axis.size(); + + if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { + // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", + input_count, ", output=1) exceeds the limit (", + context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + } + + program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), + output_size}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h new file mode 100644 index 0000000000000..0f6e6dd327e33 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/concatbase.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class ConcatProgram final : public Program { + public: + ConcatProgram(size_t axis) : Program{"Concat"}, axis_{axis} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + size_t axis_; +}; + +class Concat final : public WebGpuKernel, public ConcatBase { + public: + Concat(const OpKernelInfo& info) : WebGpuKernel(info), ConcatBase(info) { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc new file mode 100644 index 0000000000000..809616660aa9e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" + +#include "core/providers/webgpu/tensor/expand.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"); + if (input.NumComponents() != output.NumComponents()) { + const auto& output_indices = shader.AddIndices("output_indices"); + shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n " + << " let value = vec4(" << input.GetByOffset("input_offset") << ");\n" + << output.SetByOffset("global_idx", "value"); + } else { + shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + } + return Status::OK(); +} + +Status Expand::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const auto* input_shape_tensor = context.Input(1); + + auto output_dims = input_shape_tensor->DataAsSpan(); + TensorShape output_shape{}; + TensorShape input_shape = input_tensor->Shape(); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_shape, output_dims, output_shape)); + + auto* output_tensor = context.Output(0, output_shape); + const int components_i = input_shape.IsScalar() ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4 + : 1; + const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 
4 + : 1; + uint32_t data_size = gsl::narrow(output_shape.Size() / components_o); + + ExpandProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {data_size}, + }); + if (components_i != components_o) { + program.AddIndices(output_shape); + } + return context.RunProgram(program); +} + +#define WEBGPU_EXPAND_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +#define WEBGPU_EXPAND_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedNumberTypes()) +WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h new file mode 100644 index 0000000000000..046520b479257 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class ExpandProgram final : public Program { + public: + ExpandProgram() : Program{"Expand"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); +}; + +class Expand final : public WebGpuKernel { + public: + Expand(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.cc b/onnxruntime/core/providers/webgpu/tensor/flatten.cc new file mode 100644 index 0000000000000..11ded865b6be2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
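// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// Host-side mirror of the lookup emitted by ConcatProgram earlier in this patch
// (concat.cc): uniforms.size_in_concat_axis holds running sums of the non-empty
// inputs' sizes along the concat axis, calculate_input_index() scans for the
// first cumulative size exceeding the output coordinate, and the coordinate is
// rebased by the previous cumulative size. LocateConcatInput is an illustrative
// name only.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Returns {input_index, axis_coordinate_within_that_input}.
static std::pair<size_t, uint32_t> LocateConcatInput(const std::vector<uint32_t>& cumulative_sizes,
                                                     uint32_t axis_coord) {
  for (size_t i = 0; i < cumulative_sizes.size(); ++i) {
    if (axis_coord < cumulative_sizes[i]) {
      uint32_t base = (i == 0) ? 0u : cumulative_sizes[i - 1];
      return {i, axis_coord - base};
    }
  }
  return {cumulative_sizes.size(), axis_coord};  // out of range, mirrors the WGSL fallback
}

int main() {
  // three inputs with axis sizes 2, 3, 4 -> cumulative sizes 2, 5, 9
  std::vector<uint32_t> cumulative{2, 5, 9};
  auto [input, coord] = LocateConcatInput(cumulative, 6);
  std::cout << "input " << input << ", local coord " << coord << "\n";  // input 2, local coord 1
}
// --- end of editor's sketch ---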
+ +#include "core/providers/webgpu/tensor/flatten.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 1, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 9, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_KERNEL_EX( + Flatten, + kOnnxDomain, + 21, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.h b/onnxruntime/core/providers/webgpu/tensor/flatten.h new file mode 100644 index 0000000000000..5fc49a844b404 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/flatten.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Flatten final : public OpKernel { + public: + explicit Flatten(const OpKernelInfo& info) : OpKernel{info} { + axis_ = info.GetAttrOrDefault("axis", 1); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* input_tensor = context->Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int64_t input_rank = input_shape.NumDimensions(); + + // Handle negative axis + int64_t axis = axis_; + if (axis < 0) { + axis += input_rank; + } + + if (axis > input_rank) { + return Status(common::ONNXRUNTIME, common::FAIL, "Invalid value for axis, must be less than or equal to input_rank"); + } + + int64_t first_dim = 1; + for (int64_t i = 0; i < axis; i++) { + first_dim *= input_shape[i]; + } + + int64_t second_dim = 1; + for (int64_t i = axis; i < input_rank; i++) { + second_dim *= input_shape[i]; + } + + TensorShape output_shape({first_dim, second_dim}); + Tensor* output_tensor = context->Output(0, output_shape); + + const void* source = input_tensor->DataRaw(); + void* target = output_tensor->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. 
+ if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output_tensor)); + } + + return Status::OK(); + } + + private: + int64_t axis_; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc new file mode 100644 index 0000000000000..9f6e5f2420d86 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/gather.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " var indices_indices = input_indices_indices_t(0);\n"; + for (int i = 0; i < indices.Rank(); i++) { + shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; + } + shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(" << data.IndicesGet("uniforms.data_shape", axis_) << ");\n" + << " }\n" + << " var data_indices : data_indices_t;\n"; + for (int i = 0, j = 0; i < data.Rank(); i++) { + if (static_cast(i) == axis_) { + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + j += indices.Rank(); + } else { + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; + j++; + } + } + + shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices")); + + return Status::OK(); +} + +Status Gather::ComputeInternal(ComputeContext& context) const { + Prepare p; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + uint32_t data_size = gsl::narrow(p.output_tensor->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + uint32_t axis = static_cast(p.axis); + GatherProgram program{axis}; + program + .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .CacheHint(std::to_string(axis)) + .AddUniformVariables({{data_size}}); + return context.RunProgram(program); +} + +#define WEBGPU_GATHER_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +#define WEBGPU_GATHER_VERSIONED_KERNEL(OP_TYPE, 
VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.h b/onnxruntime/core/providers/webgpu/tensor/gather.h new file mode 100644 index 0000000000000..bebe13519ce43 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/tensor/gatherbase.h" + +namespace onnxruntime { +namespace webgpu { + +class GatherProgram final : public Program { + public: + GatherProgram(const uint32_t axis) : Program{"Gather"}, axis_{axis} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t axis_; +}; + +class Gather final : public WebGpuKernel, public GatherBase { + public: + Gather(const OpKernelInfo& info) : WebGpuKernel(info), GatherBase(info) {} + + protected: + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc b/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc new file mode 100644 index 0000000000000..00d8caf2624a9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
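// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// Host-side mirror of the per-element remapping that GatherProgram above emits
// in WGSL: the output rank is data_rank - 1 + indices_rank, the coordinates in
// [axis, axis + indices_rank) select a value from `indices`, and that value
// (wrapped when negative) replaces the axis coordinate of `data`. The looked-up
// indices value is passed in directly for brevity; RemapGatherIndices is an
// illustrative name only.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<uint32_t> RemapGatherIndices(const std::vector<uint32_t>& output_indices,
                                                int64_t gathered_idx,
                                                int64_t axis_dim,  // data_shape[axis]
                                                size_t axis,
                                                size_t indices_rank) {
  if (gathered_idx < 0) gathered_idx += axis_dim;  // same wrap the shader performs
  std::vector<uint32_t> data_indices;
  for (size_t i = 0; i < axis; ++i) data_indices.push_back(output_indices[i]);
  data_indices.push_back(static_cast<uint32_t>(gathered_idx));
  for (size_t i = axis + indices_rank; i < output_indices.size(); ++i) data_indices.push_back(output_indices[i]);
  return data_indices;
}

int main() {
  // data shape {4, 5}, axis = 0, indices shape {2}; output coordinate {1, 3}
  // with indices[1] == -1 maps to data coordinate {3, 3}.
  auto d = RemapGatherIndices({1, 3}, /*gathered_idx=*/-1, /*axis_dim=*/4, /*axis=*/0, /*indices_rank=*/1);
  std::cout << d[0] << ", " << d[1] << "\n";  // 3, 3
}
// --- end of editor's sketch ---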
+ +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/gather_elements.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GatherElements, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + GatherElements); + +ONNX_OPERATOR_KERNEL_EX( + GatherElements, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + GatherElements); + +Status GatherElementsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform); + const ShaderVariableHelper& indices = shader.AddInput("indices", ShaderUsage::UseUniform); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "var idx = " << indices.GetByOffset("global_idx") << ";\n" + << "if (idx < 0) {\n" + << " idx = idx + uniforms.axis_dim_limit;\n" + << "}\n" + << "var input_indices = output_indices;\n" + << input.IndicesSet("input_indices", "uniforms.axis", "u32(idx)") << ";\n" + << "let value = " << input.GetByIndices("input_indices") << ";\n" + << output.SetByOffset("global_idx", "value") << ";\n"; + + return Status::OK(); +} + +Status GatherElements::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int64_t input_rank = input_shape.NumDimensions(); + + const auto* indices_tensor = context.Input(1); + const TensorShape& indices_shape = indices_tensor->Shape(); + + // Handle negative axis + int64_t axis = axis_; + if (axis < 0) { + axis += input_rank; + } + + auto axis_dim_limit = input_shape[axis]; + + auto output_dims = indices_shape.AsShapeVector(); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + int64_t output_size = output_tensor->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + GatherElementsProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddInputs({{indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {static_cast(axis_dim_limit)}, + {static_cast(axis)}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/gather_elements.h b/onnxruntime/core/providers/webgpu/tensor/gather_elements.h new file mode 100644 index 0000000000000..f70bbda84c933 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather_elements.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
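// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// Unlike Gather, GatherElements above produces an output with the same shape as
// `indices`: for each output position only the `axis` coordinate is replaced by
// the indices value at that same position, wrapped by uniforms.axis_dim_limit
// when negative. Minimal standalone example on a 2x3 matrix with axis = 1.
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>

int main() {
  const std::array<std::array<float, 3>, 2> input{{{1, 2, 3}, {4, 5, 6}}};
  const std::array<std::array<int64_t, 2>, 2> indices{{{0, -1}, {2, 1}}};  // -1 wraps to 2
  const int64_t axis_dim_limit = 3;                                        // input_shape[axis]
  std::array<std::array<float, 2>, 2> output{};
  for (size_t r = 0; r < 2; ++r) {
    for (size_t c = 0; c < 2; ++c) {
      int64_t idx = indices[r][c];
      if (idx < 0) idx += axis_dim_limit;                  // same wrap the shader performs
      output[r][c] = input[r][static_cast<size_t>(idx)];   // axis == 1 coordinate replaced
    }
  }
  std::cout << output[0][0] << " " << output[0][1] << "\n"   // 1 3
            << output[1][0] << " " << output[1][1] << "\n";  // 6 5
}
// --- end of editor's sketch ---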
+ +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class GatherElementsProgram final : public Program { + public: + GatherElementsProgram() : Program{"GatherElements"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"axis_dim_limit", ProgramUniformVariableDataType::Int32}, + {"axis", ProgramUniformVariableDataType::Int32}); +}; + +class GatherElements final : public WebGpuKernel { + public: + GatherElements(const OpKernelInfo& info) : WebGpuKernel(info) { + axis_ = info.GetAttrOrDefault("axis", 0); + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int64_t axis_; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.cc b/onnxruntime/core/providers/webgpu/tensor/reshape.cc new file mode 100644 index 0000000000000..9ede015a0c99c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/reshape.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Reshape, + kOnnxDomain, + 21, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 14, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 13, 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 5, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.h b/onnxruntime/core/providers/webgpu/tensor/reshape.h new file mode 100644 index 0000000000000..4629598d068f7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/data_transfer_manager.h" +#include "core/providers/cpu/tensor/reshape_helper.h" + +namespace onnxruntime { +namespace webgpu { + +class Reshape final : public OpKernel { + public: + Reshape(const OpKernelInfo& info) + : OpKernel{info}, + allow_zero_(info.GetAttrOrDefault("allowzero", static_cast(0)) == 1) { + } + + Status Compute(OpKernelContext* context) const override { + // Copy the second input tensor into the shape vector + const Tensor* shapeTensor = context->Input(1); + if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (shapeTensor->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); + } + auto data_span = shapeTensor->template DataAsSpan(); + TensorShapeVector shape(data_span.begin(), data_span.end()); + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + ReshapeHelper helper(X_shape, shape, allow_zero_); + + Tensor* Y = context->Output(0, TensorShape(shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } + + private: + bool allow_zero_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc new file mode 100644 index 0000000000000..b211d48dab1c9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
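// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// The Reshape kernel above delegates shape resolution to ReshapeHelper; per the
// ONNX Reshape semantics it implements, a 0 copies the matching input dimension
// (unless allowzero=1) and a single -1 is inferred so the element count is
// preserved, after which the kernel only aliases or copies the buffer. The
// sketch below mirrors those rules in plain C++ and is not the ORT code itself.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> ResolveReshape(const std::vector<int64_t>& input_dims,
                                           std::vector<int64_t> target, bool allow_zero) {
  int64_t input_size = 1;
  for (int64_t d : input_dims) input_size *= d;
  int64_t known = 1;
  int64_t inferred_axis = -1;
  for (size_t i = 0; i < target.size(); ++i) {
    if (target[i] == 0 && !allow_zero && i < input_dims.size()) target[i] = input_dims[i];
    if (target[i] == -1) inferred_axis = static_cast<int64_t>(i);
    else known *= target[i];
  }
  if (inferred_axis >= 0) target[static_cast<size_t>(inferred_axis)] = input_size / known;
  return target;
}

int main() {
  auto dims = ResolveReshape({2, 3, 4}, {0, -1}, /*allow_zero=*/false);
  std::cout << dims[0] << " x " << dims[1] << "\n";  // 2 x 12
}
// --- end of editor's sketch ---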
+ +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 13, 14, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 15, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_KERNEL_EX( + Shape, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc new file mode 100644 index 0000000000000..136a1ba9776a0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
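// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// Shape is pure metadata, which is why every registration above pins output 0
// to CPU memory via .OutputMemoryType(OrtMemTypeCPU, 0): the result is just the
// int64 dims vector, optionally sliced by the start/end attributes introduced
// in opset 15. The clamping sketch below follows the ONNX operator spec, not
// code from this patch.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> ShapeOutput(const std::vector<int64_t>& dims,
                                        int64_t start, int64_t end) {
  const int64_t rank = static_cast<int64_t>(dims.size());
  auto clamp = [rank](int64_t v) {
    if (v < 0) v += rank;                                  // negative values count from the back
    return std::min(std::max(v, int64_t{0}), rank);        // then clamp to [0, rank]
  };
  std::vector<int64_t> out;
  for (int64_t i = clamp(start); i < clamp(end); ++i) out.push_back(dims[static_cast<size_t>(i)]);
  return out;
}

int main() {
  for (int64_t d : ShapeOutput({2, 3, 5, 7}, 1, -1)) std::cout << d << ' ';  // 3 5
  std::cout << '\n';
}
// --- end of editor's sketch ---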
+ +#include "core/providers/webgpu/tensor/squeeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.h b/onnxruntime/core/providers/webgpu/tensor/squeeze.h new file mode 100644 index 0000000000000..bc80cb86d5e8e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/squeeze.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Squeeze final : public OpKernel, public SqueezeBase { + public: + explicit Squeeze(const OpKernelInfo& info) : OpKernel{info}, SqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.cc b/onnxruntime/core/providers/webgpu/tensor/tile.cc new file mode 100644 index 0000000000000..841c36724df30 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
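// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// The WebGPU Squeeze kernel above only aliases or copies the buffer; the shape
// work is SqueezeBase::ComputeOutputShape. Per the semantics it implements, an
// explicit axes list (attribute, or input 1 from opset 13) removes only those
// size-1 dimensions, while an empty list removes every size-1 dimension, with
// negative axes counted from the end. Plain C++ mirror; ComputeSqueezedDims is
// an illustrative name only.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

static std::vector<int64_t> ComputeSqueezedDims(const std::vector<int64_t>& dims,
                                                const std::vector<int64_t>& axes) {
  std::set<int64_t> to_drop;
  for (int64_t a : axes) to_drop.insert(a < 0 ? a + static_cast<int64_t>(dims.size()) : a);
  std::vector<int64_t> out;
  for (size_t i = 0; i < dims.size(); ++i) {
    bool drop = axes.empty() ? dims[i] == 1 : to_drop.count(static_cast<int64_t>(i)) > 0;
    if (!drop) out.push_back(dims[i]);
  }
  return out;
}

int main() {
  for (int64_t d : ComputeSqueezedDims({1, 3, 1, 5}, {})) std::cout << d << ' ';   // 3 5
  std::cout << '\n';
  for (int64_t d : ComputeSqueezedDims({1, 3, 1, 5}, {0})) std::cout << d << ' ';  // 3 1 5
  std::cout << '\n';
}
// --- end of editor's sketch ---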
+ +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/tile.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Tile, + kOnnxDomain, + 6, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +ONNX_OPERATOR_KERNEL_EX( + Tile, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +Status TileProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "var input_indices: input_indices_t;\n"; + for (auto i = 0; i < input.Rank(); i++) { + std::string input_dim_i = absl::StrCat("input_dim_", i); + std::string input_dim_value = absl::StrCat("input_dim_", i, "_value"); + shader.MainFunctionBody() << "let " << input_dim_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n" + << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n" + << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; + } + + shader.MainFunctionBody() << output.SetByOffset("global_idx", input.GetByIndices("input_indices")); + + return Status::OK(); +} + +Status Tile::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + + const auto* repeats_tensor = context.Input(1); + const auto* repeats_data = repeats_tensor->Data(); + std::vector repeats; + + for (size_t i = 0; i < static_cast(repeats_tensor->Shape().Size()); i++) { + repeats.push_back(static_cast(repeats_data[i])); + } + + auto output_dims = input_shape.AsShapeVector(); + for (size_t axis = 0; axis < input_rank; axis++) { + output_dims[axis] *= repeats[axis]; + } + + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + int64_t output_size = output_tensor->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + TileProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {repeats}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.h b/onnxruntime/core/providers/webgpu/tensor/tile.h new file mode 100644 index 0000000000000..9b6ab420b3252 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
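// --- [Editor's note] Illustrative sketch, not part of this patch. ---
// Host-side mirror of the TileProgram mapping above: the output shape is
// input_shape[i] * repeats[i] per axis, and every output coordinate folds back
// onto the input with a plain modulo, which is exactly what the generated WGSL
// does for each axis.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> input_shape{2, 3};
  const std::vector<uint32_t> repeats{2, 1};
  std::vector<int64_t> output_shape(input_shape);
  for (size_t axis = 0; axis < input_shape.size(); ++axis) output_shape[axis] *= repeats[axis];

  // Map one output coordinate back to its source element, as the shader does.
  const std::vector<int64_t> output_indices{3, 1};  // row 3 of the 4x3 output
  std::vector<int64_t> input_indices(input_shape.size());
  for (size_t i = 0; i < input_shape.size(); ++i) input_indices[i] = output_indices[i] % input_shape[i];

  std::cout << output_shape[0] << "x" << output_shape[1] << " output; coordinate (3, 1) reads input ("
            << input_indices[0] << ", " << input_indices[1] << ")\n";  // 4x3 output; ... reads input (1, 1)
}
// --- end of editor's sketch ---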
+ +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class TileProgram final : public Program { + public: + TileProgram() : Program{"Tile"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"repeats", ProgramUniformVariableDataType::Uint32}); +}; + +class Tile final : public WebGpuKernel { + public: + Tile(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc new file mode 100644 index 0000000000000..c40ec43dd0009 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_KERNEL_EX( + Transpose, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] != 1) { + new_shape.push_back(shape[i]); + } + if (shape[adjusted_perm[i]] != 1) { + new_perm.push_back(adjusted_perm[i]); + } + } +}; + +Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + + if (use_shared_) { + shader.AdditionalImplementation() << "var tile : array, tile_size>;\n"; + shader.MainFunctionBody() << " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + << " tile[local_id.y][local_id.x] = " << input.GetByIndices("a_indices_t(input_row, input_col)") << ";\n" + << " }\n" + " 
workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n" + << " " << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n" + << " }"; + } else { + shader.AdditionalImplementation() << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; + for (size_t i = 0; i < perm_.size(); ++i) { + shader.AdditionalImplementation() << " a[" << perm_[i] << "] = i[" << i << "];\n"; + } + shader.AdditionalImplementation() << " return a;\n" + "}\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let indices = " << output.OffsetToIndices("global_idx") + << ";\n" + " let x_indices = perm(indices);\n" + " " + << output.SetByOffset("global_idx", input.GetByIndices("x_indices")); + } + return Status::OK(); +} + +Status Transpose::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const bool channels_last = new_perm == InlinedVector({2, 3, 1}); + const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; + auto new_input_shape = input_shape; + TensorShape new_output_shape(output_dims); + if (use_shared) { + new_input_shape = channels_last + ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) + : channels_first + ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]}) + : new_shape; + new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); + } + + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); + TransposeProgram program{*p_perm, use_shared}; + if (use_shared) { + program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); + } + + program + .CacheHint(absl::StrJoin(*p_perm, "-")) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + .AddUniformVariables({ + {static_cast(output_size)}, + }); + + use_shared ? 
program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h new file mode 100644 index 0000000000000..7cf5c1fe0865d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + Status ComputeInternal(ComputeContext& context) const override; + constexpr static uint32_t TILE_SIZE = 16; +}; + +class TransposeProgram final : public Program { + public: + TransposeProgram(const gsl::span& permutations, bool use_shared) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", Transpose::TILE_SIZE}); + + private: + InlinedVector perm_; + const bool use_shared_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc new file mode 100644 index 0000000000000..4bcef4fd79296 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/unsqueeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h new file mode 100644 index 0000000000000..0ae9d50f6d4e7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/tensor/unsqueeze.h"
+#include "core/framework/data_transfer_manager.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+class Unsqueeze final : public OpKernel, public UnsqueezeBase {
+ public:
+  explicit Unsqueeze(const OpKernelInfo& info) : OpKernel{info}, UnsqueezeBase(info) {}
+
+  Status Compute(OpKernelContext* context) const override {
+    const Tensor* X = context->Input<Tensor>(0);
+    if (X == nullptr) {
+      return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set");
+    }
+    const TensorShape& X_shape = X->Shape();
+
+    TensorShapeVector axes;
+    size_t num_inputs = context->InputCount();
+    if (num_inputs == 2) {  // axes is an input
+      const Tensor* axes_tensor = context->Input<Tensor>(1);
+      ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
+      ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 ||
+                      axes_tensor->Shape().NumDimensions() == 1,
+                  "An axes tensor must be a scalar or a vector tensor.");
+      auto data_span = axes_tensor->template DataAsSpan<int64_t>();
+      axes.assign(data_span.begin(), data_span.end());
+    } else {
+      axes.assign(axes_.begin(), axes_.end());
+    }
+
+    TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes);
+    Tensor* Y = context->Output(0, TensorShape(output_shape));
+    const void* source = X->DataRaw();
+    void* target = Y->MutableDataRaw();
+    // If source and target pointers are not equal (non-inplace operation), we need to copy the data.
+    if (target != source) {
+      ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y));
+    }
+
+    return Status::OK();
+  }
+};
+
+} // namespace webgpu
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc
new file mode 100644
index 0000000000000..524dd07d5b710
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/tensor/where.cc
@@ -0,0 +1,188 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/inlined_containers.h"
+#include "core/providers/webgpu/tensor/where.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/providers/webgpu/shader_helper.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+// Compute the Where operator's output shape based on three-way broadcasting.
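+// For illustration of the rule implemented below: dimensions are aligned from the right and
+// each output dim is the max of the three (a dim of 1 broadcasts), e.g. cond {2,1,4},
+// X {1,3,4}, Y {4} -> output {2,3,4}. A 0-sized dim pairs with 1, e.g. {0,3} vs {1,3} -> {0,3}.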
+Status ComputeOutputShape(const TensorShape& cond_shape, + const TensorShape& x_shape, const TensorShape& y_shape, TensorShape& output_shape) { + size_t cond_rank = cond_shape.NumDimensions(); + size_t x_rank = x_shape.NumDimensions(); + size_t y_rank = y_shape.NumDimensions(); + size_t output_rank = std::max(std::max(cond_rank, x_rank), y_rank); + + std::vector output_dims(output_rank, 0); + for (size_t i = 0; i < output_rank; ++i) { + int64_t cond_dim = 1; + if (i < cond_rank) + cond_dim = cond_shape[cond_rank - 1 - i]; + + int64_t x_dim = 1; + if (i < x_rank) + x_dim = x_shape[x_rank - 1 - i]; + + int64_t y_dim = 1; + if (i < y_rank) + y_dim = y_shape[y_rank - 1 - i]; + + int64_t output_dim = std::max({cond_dim, x_dim, y_dim}); + // special case to handle a dim of 0 which can be broadcast with a 1 + if (output_dim == 1) + output_dim = std::min({cond_dim, x_dim, y_dim}); + + const auto node_name = "Where"; + if (cond_dim != output_dim && cond_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": condition operand cannot broadcast on dim ", cond_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (x_dim != output_dim && x_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": X operand cannot broadcast on dim ", x_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (y_dim != output_dim && y_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": Y operand cannot broadcast on dim ", y_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + output_dims[output_rank - 1 - i] = output_dim; + } + + output_shape = TensorShape(output_dims); + return Status::OK(); +} + +Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& c_input = shader.AddInput("c_data", ShaderUsage::UseUniform); + const auto& a_input = shader.AddInput("a_data", ShaderUsage::UseUniform); + const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + const auto expression = [](std::string_view a, std::string_view b, std::string_view c) -> auto { + return absl::StrCat("select(", b, ", ", a, ", ", c, ")"); + }; + + if (!is_broadcast_) { + shader.MainFunctionBody() << output.SetByOffset( + "global_idx", + expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx"))); + + } else { + const auto& c_indices = shader.AddIndices("c_indices"); + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + const auto& output_indices = shader.AddIndices("output_indices"); + + const auto single_assignment = + [&expression, &shader, &output_indices, &a_indices, &b_indices, &c_indices]( + std::string_view rest_str, const std::string& x, std::string_view type_cast = "") + -> void { + const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; + const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; + const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; + + shader.MainFunctionBody() << "let output_indices" << x << " = " << 
output_indices.OffsetToIndices("global_idx * 4 + " + x) << ";\n" + << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_b" << x << " = " << b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_c" << x << " = " << c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let index_a" << x << " = offset_a" << x << " / 4;\n" + << "let index_b" << x << " = offset_b" << x << " / 4;\n" + << "let index_c" << x << " = offset_c" << x << " / 4;\n" + << "let component_a" << x << " = offset_a" << x << " % 4;\n" + << "let component_b" << x << " = offset_b" << x << " % 4;\n" + << "let component_c" << x << " = offset_c" << x << " % 4;\n" + << rest_str << "[" << x << "] = " << type_cast << "(" << expression(a_expression, b_expression, c_expression) << ");\n"; + }; + + if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { + shader.MainFunctionBody() << "var data = vec4(0);\n"; + single_assignment("data", "0", "u32"); + single_assignment("data", "1", "u32"); + single_assignment("data", "2", "u32"); + single_assignment("data", "3", "u32"); + shader.MainFunctionBody() << "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; + } else { + single_assignment("output_data[global_idx]", "0"); + single_assignment("output_data[global_idx]", "1"); + single_assignment("output_data[global_idx]", "2"); + single_assignment("output_data[global_idx]", "3"); + } + } + + return Status::OK(); +} + +Status Where::ComputeInternal(ComputeContext& context) const { + const auto* cond_tensor = context.Input(0); + const auto* x_tensor = context.Input(1); + const auto* y_tensor = context.Input(2); + const auto& cond_shape = cond_tensor->Shape(); + const auto& x_shape = x_tensor->Shape(); + const auto& y_shape = y_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); + auto* output_tensor = context.Output(0, output_shape); + constexpr int component = 4; + uint32_t vec_size = gsl::narrow_cast((output_shape.Size() + 3) / component); + const auto is_broadcast = !(x_shape == y_shape && + y_shape == cond_shape); + WhereProgram program{is_broadcast}; + program + .CacheHint(is_broadcast) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4}, + {x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4}, + {y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddUniformVariables({ + {static_cast(vec_size)}, + }); + if (is_broadcast) { + program + .AddIndices(cond_shape) + .AddIndices(x_shape) + .AddIndices(y_shape) + .AddIndices(output_tensor->Shape()); + } + return context.RunProgram(program); +} + +namespace { +const std::vector& WhereOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Where, + 
kOnnxDomain, + 9, 15, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +ONNX_OPERATOR_KERNEL_EX( + Where, + kOnnxDomain, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/where.h b/onnxruntime/core/providers/webgpu/tensor/where.h new file mode 100644 index 0000000000000..e46b24e9ba2e5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/where.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class WhereProgram final : public Program { + public: + WhereProgram(bool is_broadcast) : Program{"Where"}, is_broadcast_{is_broadcast} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool is_broadcast_; +}; + +class Where final : public WebGpuKernel { + public: + Where(const OpKernelInfo& info) : WebGpuKernel{info} { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc new file mode 100644 index 0000000000000..d66c2a79d28a8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -0,0 +1,689 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include + +#include "dawn/dawn_proc.h" +#if !defined(USE_EXTERNAL_DAWN) +#include "dawn/native/DawnNative.h" +#endif + +#include "core/common/common.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/program_cache_key.h" +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table) { + std::call_once(init_flag_, [this, &webgpu_ep_info, dawn_proc_table]() { + // Initialization.Step.1 - Create wgpu::Instance + if (instance_ == nullptr) { + const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); +#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) + ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); +#else +#if !defined(USE_EXTERNAL_DAWN) + if (dawn_procs == nullptr) { + dawn_procs = &dawn::native::GetProcs(); + } +#else + ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); +#endif + dawnProcSetProcs(dawn_procs); +#endif + + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.features.timedWaitAnyEnable = true; + instance_ = wgpu::CreateInstance(&instance_desc); + + ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance."); + } + + // Initialization.Step.2 - Create wgpu::Adapter + if (adapter_ == nullptr) { + wgpu::RequestAdapterOptions req_adapter_options = {}; + wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; + req_adapter_options.nextInChain = &adapter_toggles_desc; + req_adapter_options.backendType = static_cast(webgpu_ep_info.backend_type); + req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance; + + auto enabled_adapter_toggles = GetEnabledAdapterToggles(); + adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size(); + adapter_toggles_desc.enabledToggles = enabled_adapter_toggles.data(); + + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter( + &req_adapter_options, + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) { + ORT_ENFORCE(status == wgpu::RequestAdapterStatus::Success, "Failed to get a WebGPU adapter: ", std::string_view{message}); + *ptr = adapter; + }, + &adapter_), + UINT64_MAX)); + ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter."); + } + + // Initialization.Step.3 - Create wgpu::Device + if (device_ == nullptr) { + wgpu::DeviceDescriptor device_desc = {}; + wgpu::DawnTogglesDescriptor device_toggles_desc = {}; + device_desc.nextInChain = &device_toggles_desc; + + auto enabled_device_toggles = GetEnabledDeviceToggles(); + device_toggles_desc.enabledToggleCount = enabled_device_toggles.size(); + device_toggles_desc.enabledToggles = enabled_device_toggles.data(); + + auto disabled_device_toggles = GetDisabledDeviceToggles(); + device_toggles_desc.disabledToggleCount = disabled_device_toggles.size(); + device_toggles_desc.disabledToggles = disabled_device_toggles.data(); + + std::vector required_features = GetAvailableRequiredFeatures(adapter_); + if (required_features.size() > 0) { + device_desc.requiredFeatures = 
required_features.data(); + device_desc.requiredFeatureCount = required_features.size(); + } + wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_); + device_desc.requiredLimits = &required_limits; + + // TODO: revise temporary error handling + device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, const char* message) { + LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << message; + }); + // TODO: revise temporary device lost handling + device_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device& /*device*/, wgpu::DeviceLostReason reason, const char* message) { + // cannot use ORT logger because it may be already destroyed + std::cerr << "WebGPU device lost (" << int(reason) << "): " << message; + }); + + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice( + &device_desc, + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) { + ORT_ENFORCE(status == wgpu::RequestDeviceStatus::Success, "Failed to get a WebGPU device: ", std::string_view{message}); + *ptr = device; + }, + &device_), + UINT64_MAX)); + ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); + } + + // cache adapter info + ORT_ENFORCE(Adapter().GetInfo(&adapter_info_)); + // cache device limits + wgpu::SupportedLimits device_supported_limits; + ORT_ENFORCE(Device().GetLimits(&device_supported_limits)); + device_limits_ = device_supported_limits.limits; + + // create buffer manager + buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode); + + // create program manager + program_mgr_ = std::make_unique(Device(), DeviceLimits()); + + // set query type + if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { + query_type_ = TimestampQueryType::InsidePasses; + } else if (device_.HasFeature(wgpu::FeatureName::TimestampQuery)) { + query_type_ = TimestampQueryType::AtPasses; + } else { + query_type_ = TimestampQueryType::None; + } + }); +} + +Status WebGpuContext::Wait(wgpu::Future f) { + auto status = instance_.WaitAny(f, UINT64_MAX); + if (status == wgpu::WaitStatus::Success) { + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); +} + +Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { + const auto& inputs = program.Inputs(); + const auto& outputs = program.Outputs(); + + if (outputs.size() == 0) { + return Status::OK(); + } + + if (ValidationMode() >= ValidationMode::Basic) { + ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { + const auto* tensor = input.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All inputs must be tensors on WebGPU buffers."); + + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { + const auto* tensor = output.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All outputs must be tensors on WebGPU 
buffers."); + } + + const ProgramMetadata& metadata = program.Metadata(); + + // validate program metadata + if (ValidationMode() >= ValidationMode::Basic) { + const auto& [constants, overridable_constants, uniform_variables] = metadata; + + // check overridable constants + ORT_RETURN_IF(program.OverridableConstants().size() != overridable_constants.size(), + "Size of overridable constants mismatch in program \"", program.Name(), + "\", Expected: ", overridable_constants.size(), + ", Actual: ", program.OverridableConstants().size()); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_overridable_constants = program.OverridableConstants().size(); + for (size_t i = 0; i < num_overridable_constants; ++i) { + const auto& override_value = program.OverridableConstants()[i]; + const auto& definition = overridable_constants[i]; + ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, + "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.type, + ", Actual: ", override_value.type); + ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, + "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), + "\""); + } + } + + // check uniform variables + ORT_RETURN_IF(program.UniformVariables().size() != uniform_variables.size(), + "Size of uniform_value variables mismatch in program \"", program.Name(), + "\", Expected: ", uniform_variables.size(), + ", Actual: ", program.UniformVariables().size()); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_uniform_variables = program.UniformVariables().size(); + for (size_t i = 0; i < num_uniform_variables; ++i) { + const auto& uniform_value = program.UniformVariables()[i]; + const auto& definition = uniform_variables[i]; + ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, + "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.data_type, + ", Actual: ", uniform_value.data_type); + } + } + } + + uint32_t x = program.DispatchGroupSizeX(); + uint32_t y = program.DispatchGroupSizeY(); + uint32_t z = program.DispatchGroupSizeZ(); + ORT_RETURN_IF_ERROR(program_mgr_->NormalizeDispatchGroupSize(x, y, z)); + + bool is_1d_dispatch = (y == 1 && z == 1); + + auto key = CalculateProgramCacheKey(program, is_1d_dispatch); + + if (is_profiling_) { + PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), + program.Name(), + key, + inputs, + outputs); + pending_kernels_.emplace_back(std::move(pending_kernel_info)); + } + + LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")"; + + const auto* program_artifact = program_mgr_->Get(key); + if (program_artifact == nullptr) { + wgpu::ComputePipeline compute_pipeline; + std::vector shape_uniform_ranks; + auto status = program_mgr_->Build(program, + metadata, +#ifndef NDEBUG // if debug build + key, +#endif + x, + y, + z, + compute_pipeline, + shape_uniform_ranks); + ORT_RETURN_IF_ERROR(status); + program_artifact = program_mgr_->Set(key, ProgramArtifact{program, + std::move(compute_pipeline), + std::move(shape_uniform_ranks)}); +#ifndef NDEBUG // if debug build + ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); +#endif + } + + // prepare shape uniforms for shader variables (if any) 
and user defined uniforms + std::vector shape_uniforms; + shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); + if (ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size() + program.Indices().size(), + "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), + ") does not match current program (input: ", inputs.size(), + ", output: ", outputs.size(), + ", indices: ", program.Indices().size(), ")"); + } + + auto append_shape_uniforms = [&shape_uniforms, program_artifact](size_t i, const TensorShape& shape) { + if (program_artifact->shape_uniform_ranks[i] > 0) { + size_t expected_rank = static_cast(program_artifact->shape_uniform_ranks[i]); + ORT_RETURN_IF(expected_rank != shape.NumDimensions(), + "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", expected_rank, + ", Actual: ", shape.NumDimensions()); + + std::vector dims(expected_rank); + std::vector stride(expected_rank - 1); + for (size_t j = 0; j < expected_rank; ++j) { + dims[j] = gsl::narrow(shape[j]); + if (j < expected_rank - 1) { + stride[j] = gsl::narrow(shape.SizeFromDimension(j + 1)); + } + } + + shape_uniforms.emplace_back(gsl::make_span(dims)); + if (expected_rank > 1) { + shape_uniforms.emplace_back(gsl::make_span(stride)); + } + } + return Status::OK(); + }; + + for (size_t i = 0; i < inputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i, + inputs[i].use_override_shape ? inputs[i].override_shape : inputs[i].tensor->Shape())); + } + for (size_t i = 0; i < outputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size(), + outputs[i].use_override_shape ? outputs[i].override_shape : outputs[i].tensor->Shape())); + } + for (size_t i = 0; i < program.Indices().size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size() + outputs.size(), program.Indices()[i])); + } + + const size_t uniform_count = shape_uniforms.size() + program.UniformVariables().size(); + size_t current_offset = 0; + std::vector> uniform_and_offsets; + uniform_and_offsets.reserve(uniform_count); + for (size_t i = 0; i < uniform_count; i++) { + const auto& uniform = i < shape_uniforms.size() ? shape_uniforms[i] + : program.UniformVariables()[i - shape_uniforms.size()]; + size_t length = uniform.length; + if (length == 0) { // skip zero-length uniform + continue; + } + + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // https://www.w3.org/TR/WGSL/#alignof + size_t base_alignment = is_f16 + ? (length > 4 ? 16 : length > 2 ? 8 + : length * element_size) + : (length > 2 ? 16 : length * element_size); + size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; + + current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + uniform_and_offsets.emplace_back(uniform, current_offset); + + // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). + // For float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). + size_t element_per_struct = is_f16 ? 8 : 4; + current_offset += + length > 4 ? 
(length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + } + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. + constexpr size_t max_alignment_of_field = 16; + const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; + + WGPUBuffer uniform_buffer = nullptr; + if (uniform_buffer_total_size > 0) { + std::vector uniform_data_buffer(uniform_buffer_total_size); + + for (auto const& [uniform, offset] : uniform_and_offsets) { + memcpy(uniform_data_buffer.data() + offset, uniform.data.data(), uniform.data.size()); + } + + uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + + WriteTimestamp(num_pending_dispatches_ * 2); + + uint32_t entry_index = 0; + std::vector bind_group_entries; + for (const auto& input : inputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(const_cast(input.tensor->DataRaw()))}); + } + for (const auto& output : outputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output.tensor->MutableDataRaw())}); + } + if (uniform_buffer) { + bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer}); + } + + wgpu::BindGroupDescriptor bind_group_desc{}; + bind_group_desc.layout = program_artifact->compute_pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = bind_group_entries.size(); + bind_group_desc.entries = bind_group_entries.data(); + bind_group_desc.label = program_artifact->name.c_str(); + + auto bind_group = Device().CreateBindGroup(&bind_group_desc); + + // TODO support graph capture + + compute_pass_encoder.SetPipeline(program_artifact->compute_pipeline); + compute_pass_encoder.SetBindGroup(0, bind_group); + compute_pass_encoder.DispatchWorkgroups(x, y, z); + + if (uniform_buffer) { + buffer_mgr_->Release(uniform_buffer); + } + + WriteTimestamp(num_pending_dispatches_ * 2 + 1); + + ++num_pending_dispatches_; + + if (num_pending_dispatches_ >= max_num_pending_dispatches_ || + (is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) { + EndComputePass(); + } + if (num_pending_dispatches_ >= max_num_pending_dispatches_) { + Flush(); + num_pending_dispatches_ = 0; + } + + return Status::OK(); +} + +std::vector WebGpuContext::GetEnabledAdapterToggles() const { + // See the description of all the toggles in toggles.cpp + // "use_dxc" for Shader Model 6+ features (e.g. float16) + // "allow_unsafe_apis" for chromium experimental features + constexpr const char* toggles[] = { + "use_dxc", + "allow_unsafe_apis", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetEnabledDeviceToggles() const { + // Enable / disable other toggles that may affect the performance. + // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" + constexpr const char* toggles[] = { + "skip_validation", // only use "skip_validation" when ValidationMode is set to "Disabled" + "disable_robustness", + "d3d_disable_ieee_strictness", + }; + return std::vector(ValidationMode() >= ValidationMode::WGPUOnly + ? 
std::begin(toggles) + 1 + : std::begin(toggles), + std::end(toggles)); +} + +std::vector WebGpuContext::GetDisabledDeviceToggles() const { + constexpr const char* toggles[] = { + "lazy_clear_resource_on_first_use", + "timestamp_quantization", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const { + wgpu::RequiredLimits required_limits{}; + wgpu::SupportedLimits adapter_limits; + ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); + + required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; + required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; + required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + +void WebGpuContext::WriteTimestamp(uint32_t query_index) { + if (!is_profiling_ || query_type_ != TimestampQueryType::InsidePasses) { + return; + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + compute_pass_encoder.WriteTimestamp(query_set_, query_index); +} + +void WebGpuContext::StartProfiling() { + if (query_type_ == TimestampQueryType::None) { + return; + } + + is_profiling_ = true; + + const uint32_t query_count = max_num_pending_dispatches_ * 2; + + if (!query_set_) { + // Create query set + wgpu::QuerySetDescriptor querySetDescriptor; + querySetDescriptor.count = query_count; + querySetDescriptor.type = wgpu::QueryType::Timestamp; + query_set_ = device_.CreateQuerySet(&querySetDescriptor); + } + + if (!query_resolve_buffer_) { + // Create resolve buffer + wgpu::BufferDescriptor bufferDescriptor; + bufferDescriptor.size = query_count * sizeof(uint64_t); + bufferDescriptor.usage = wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc | + wgpu::BufferUsage::CopyDst; + query_resolve_buffer_ = device_.CreateBuffer(&bufferDescriptor); + } +} + +void WebGpuContext::CollectProfilingData(profiling::Events& events) { + if (!pending_queries_.empty()) { + for (const auto& pending_query : pending_queries_) { + const auto& pending_kernels = pending_query.kernels; + const auto& query_read_buffer = pending_query.query_buffer; + + ORT_ENFORCE(Wait(query_read_buffer.MapAsync(wgpu::MapMode::Read, + 0, + query_read_buffer.GetSize(), + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::MapAsyncStatus status, 
const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + auto mapped_data = static_cast(query_read_buffer.GetConstMappedRange()); + + for (size_t i = 0; i < pending_kernels.size(); i++) { + const PendingKernelInfo& pending_kernel_info = pending_kernels[i]; + const auto& inputs = pending_kernel_info.inputs; + const auto& outputs = pending_kernel_info.outputs; + + SS(shapes, 128); + for (size_t s = 0; s < inputs.size(); s++) { + const auto& input = inputs[s]; + shapes << "inputs[" << s << "] = " << input.override_shape.ToString() << " "; + } + for (size_t s = 0; s < outputs.size(); s++) { + const auto& output = outputs[s]; + shapes << "outputs[" << s << "] = " << output.override_shape.ToString() << " "; + } + + if (gpu_timestamp_offset_ == 0) { + gpu_timestamp_offset_ = mapped_data[i * 2]; + // TODO: apply CPU-GPU time offset so that timestamps are aligned + } + uint64_t start_time = mapped_data[i * 2] - gpu_timestamp_offset_; + uint64_t end_time = mapped_data[i * 2 + 1] - gpu_timestamp_offset_; + + const std::unordered_map& event_args = { + {"shapes", SS_GET(shapes)}, + {"cache_key", pending_kernel_info.cache_key}, + }; + + profiling::EventRecord event(profiling::API_EVENT, + -1, + -1, + pending_kernel_info.name, + static_cast(std::round(start_time / 1000.0)), + static_cast(std::round((end_time - start_time) / 1000.0)), + event_args); + events.emplace_back(std::move(event)); + } + + query_read_buffer.Unmap(); + query_read_buffer.Destroy(); + } + + pending_queries_.clear(); + } + + is_profiling_ = false; +} + +void WebGpuContext::EndProfiling(TimePoint /* tp */, profiling::Events& events, profiling::Events& cached_events) { + // This function is called when no active inference is ongoing. + ORT_ENFORCE(!is_profiling_, "Profiling is ongoing in an inference run."); + + if (query_type_ != TimestampQueryType::None) { + // No pending kernels or queries should be present at this point. They should have been collected in CollectProfilingData. 
+ ORT_ENFORCE(pending_kernels_.empty() && pending_queries_.empty(), "Pending kernels or queries are not empty."); + + events.insert(events.end(), + std::make_move_iterator(cached_events.begin()), + std::make_move_iterator(cached_events.end())); + + cached_events.clear(); + } else { + LOGS_DEFAULT(WARNING) << "TimestampQuery is not supported in this device."; + } +} + +void WebGpuContext::Flush() { + if (!current_command_encoder_) { + return; + } + + EndComputePass(); + + if (is_profiling_ && num_pending_dispatches_ > 0) { + uint32_t query_count = num_pending_dispatches_ * 2; + current_command_encoder_.ResolveQuerySet( + query_set_, + 0, + query_count, + query_resolve_buffer_, + 0); + + wgpu::BufferDescriptor bufferDescriptor; + bufferDescriptor.size = query_count * sizeof(uint64_t); + bufferDescriptor.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst; + wgpu::Buffer query_read_buffer = device_.CreateBuffer(&bufferDescriptor); + + current_command_encoder_.CopyBufferToBuffer( + query_resolve_buffer_, + 0, + query_read_buffer, + 0, + query_count * sizeof(uint64_t)); + + pending_queries_.emplace_back(std::move(pending_kernels_), query_read_buffer); + pending_kernels_.clear(); + } + + auto command_buffer = current_command_encoder_.Finish(); + Device().GetQueue().Submit(1, &command_buffer); + BufferManager().RefreshPendingBuffers(); + current_command_encoder_ = nullptr; + num_pending_dispatches_ = 0; +} + +std::unordered_map> WebGpuContextFactory::contexts_; +std::mutex WebGpuContextFactory::mutex_; + +WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode) { + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device."); + } else { + // for context ID > 0, user must provide custom WebGPU instance, adapter and device. 
+ ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr, + "WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device."); + } + + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + if (it == contexts_.end()) { + GSL_SUPPRESS(r.11) + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, validation_mode)); + it = contexts_.emplace(context_id, std::move(context)).first; + } else if (context_id != 0) { + ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, + "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device."); + } + return *it->second; +} + +WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); + + return *it->second; +} + +void WebGpuContextFactory::Cleanup() { + std::lock_guard lock(mutex_); + contexts_.clear(); +} + +void CleanupWebGpuContexts() { + WebGpuContextFactory::Cleanup(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h new file mode 100644 index 0000000000000..be05b06523b9c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/program_manager.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { +class WebGpuContext; +class ComputeContext; +class ProgramBase; + +class WebGpuContextFactory { + public: + static WebGpuContext& CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode); + static WebGpuContext& GetContext(int context_id); + + static void Cleanup(); + + private: + WebGpuContextFactory() {} + + static std::unordered_map> contexts_; + static std::mutex mutex_; +}; + +// Class WebGpuContext includes all necessary resources for the context. 
+class WebGpuContext final { + public: + void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table); + + Status Wait(wgpu::Future f); + + const wgpu::Adapter& Adapter() const { return adapter_; } + const wgpu::Device& Device() const { return device_; } + + const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; } + const wgpu::Limits& DeviceLimits() const { return device_limits_; } + + const wgpu::CommandEncoder& GetCommandEncoder() { + if (!current_command_encoder_) { + current_command_encoder_ = device_.CreateCommandEncoder(); + } + return current_command_encoder_; + } + + const wgpu::ComputePassEncoder& GetComputePassEncoder() { + if (!current_compute_pass_encoder_) { + auto& command_encoder = GetCommandEncoder(); + + wgpu::ComputePassDescriptor compute_pass_desc{}; + + if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) { + wgpu::ComputePassTimestampWrites timestampWrites = { + query_set_, num_pending_dispatches_ * 2, num_pending_dispatches_ * 2 + 1}; + compute_pass_desc.timestampWrites = ×tampWrites; + } + + current_compute_pass_encoder_ = command_encoder.BeginComputePass(&compute_pass_desc); + } + return current_compute_pass_encoder_; + } + + void EndComputePass() { + if (current_compute_pass_encoder_) { + current_compute_pass_encoder_.End(); + current_compute_pass_encoder_ = nullptr; + } + } + + void Flush(); + + webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } + + inline webgpu::ValidationMode ValidationMode() const { + return validation_mode_; + } + + void StartProfiling(); + void CollectProfilingData(profiling::Events& events); + void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); + + Status Run(ComputeContext& context, const ProgramBase& program); + + private: + enum class TimestampQueryType { + None = 0, + InsidePasses, + AtPasses + }; + + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode) + : instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + + std::vector GetEnabledAdapterToggles() const; + std::vector GetEnabledDeviceToggles() const; + std::vector GetDisabledDeviceToggles() const; + std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const; + wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) const; + void WriteTimestamp(uint32_t query_index); + + struct PendingKernelInfo { + PendingKernelInfo(std::string_view kernel_name, + std::string_view program_name, + std::string_view cache_key, + const std::vector& inputs, + const std::vector& outputs) + : name{absl::StrJoin({kernel_name, program_name}, "_")}, cache_key{cache_key}, inputs{inputs}, outputs{outputs} {} + + PendingKernelInfo(PendingKernelInfo&&) = default; + PendingKernelInfo& operator=(PendingKernelInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo); + + std::string name; + std::string cache_key; + std::vector inputs; + std::vector outputs; + }; + + struct PendingQueryInfo { + PendingQueryInfo(std::vector&& kernels, wgpu::Buffer query_buffer) + : kernels{std::move(kernels)}, query_buffer{query_buffer} {} + + PendingQueryInfo(PendingQueryInfo&&) = default; + PendingQueryInfo& operator=(PendingQueryInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingQueryInfo); + + std::vector kernels; + wgpu::Buffer query_buffer; + }; + + friend 
class WebGpuContextFactory; + + std::once_flag init_flag_; + + wgpu::Instance instance_; + wgpu::Adapter adapter_; + wgpu::Device device_; + + webgpu::ValidationMode validation_mode_; + + wgpu::AdapterInfo adapter_info_; + wgpu::Limits device_limits_; + + wgpu::CommandEncoder current_command_encoder_; + wgpu::ComputePassEncoder current_compute_pass_encoder_; + + std::unique_ptr buffer_mgr_; + std::unique_ptr program_mgr_; + + uint32_t num_pending_dispatches_ = 0; + const uint32_t max_num_pending_dispatches_ = 16; + + // profiling + TimestampQueryType query_type_; + wgpu::QuerySet query_set_; + wgpu::Buffer query_resolve_buffer_; + + // info of kernels pending submission for a single batch + std::vector pending_kernels_; + // info of queries pending appending to profiling events + std::vector pending_queries_; + + uint64_t gpu_timestamp_offset_ = 0; + bool is_profiling_ = false; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 00ebdd5583d2e..295a8de31ed50 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -3,6 +3,9 @@ #include "core/providers/webgpu/webgpu_execution_provider.h" +#ifdef __EMSCRIPTEN__ +#include +#endif #include #include #include @@ -13,6 +16,7 @@ #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #endif +#include "allocator.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/fallback_cpu_capability.h" @@ -20,6 +24,10 @@ #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_profiler.h" + namespace onnxruntime { namespace webgpu { @@ -65,6 +73,330 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Memcpy); +#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, End, Op)> + +#define KERNEL_CREATE_INFO(Start, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, Op)> + +#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, type, Op)> + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Abs); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Abs); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Neg); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Neg); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Floor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Floor); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Ceil); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Ceil); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Reciprocal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 
Reciprocal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sqrt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sqrt); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Exp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Exp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Erf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Erf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, HardSigmoid); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Log); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Log); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Sin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Cos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Tan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Asin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Acos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Atan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Sinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Cosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Asinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Acosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Atanh); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, Not); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, Elu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, 
kOnnxDomain, 6, 12, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Relu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, Gelu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, 
ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Sub); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Mul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Div); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 11, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Pow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 10, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Equal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Greater); +class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Greater); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Less); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LessOrEqual); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, 18, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Shape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 5, 12, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 18, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); + +class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, Conv); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul); +class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 3, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 4, 10, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 19, Resize); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather); + 
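For illustration, each of these forward declarations corresponds to a kernel definition elsewhere that uses the standard ONNX Runtime registration macros. A minimal sketch of that pattern for one of the ops above (the exact per-kernel builder settings are an assumption and are not shown in this diff):

// Sketch of the matching definition in the kernel's .cc file:
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Sigmoid,                       // op type
    kOnnxDomain,                   // domain
    6, 12,                         // opset range covered by this registration
    kWebGpuExecutionProvider,      // provider
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),  // helper added in webgpu_supported_types.h below
    Sigmoid);                      // kernel class declared above
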
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, InstanceNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, float, Einsum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization); 
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -72,6 +404,322 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + + // element-wise operators + // unary - math + KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), + KERNEL_CREATE_INFO(13, Abs), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + KERNEL_CREATE_INFO(13, Neg), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + KERNEL_CREATE_INFO(13, Floor), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + KERNEL_CREATE_INFO(13, Ceil), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + KERNEL_CREATE_INFO(13, Reciprocal), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + KERNEL_CREATE_INFO(13, Sqrt), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + KERNEL_CREATE_INFO(13, Exp), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + KERNEL_CREATE_INFO(13, Erf), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + KERNEL_CREATE_INFO(13, Log), + + KERNEL_CREATE_INFO(7, Sin), + KERNEL_CREATE_INFO(7, Cos), + KERNEL_CREATE_INFO(7, Tan), + KERNEL_CREATE_INFO(7, Asin), + KERNEL_CREATE_INFO(7, Acos), + KERNEL_CREATE_INFO(7, Atan), + KERNEL_CREATE_INFO(9, Sinh), + KERNEL_CREATE_INFO(9, Cosh), + KERNEL_CREATE_INFO(9, Asinh), + KERNEL_CREATE_INFO(9, Acosh), + KERNEL_CREATE_INFO(9, Atanh), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), + KERNEL_CREATE_INFO(13, Tanh), + KERNEL_CREATE_INFO(1, Not), + + 
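  // KERNEL_CREATE_INFO / KERNEL_CREATE_INFO_VERSIONED entries appear to be shorthand for the
  // explicit BuildKernelCreateInfo<...> form used elsewhere in this table. A sketch of what such
  // helper macros typically expand to (an assumption; the macro definitions are not part of this hunk):
  //   #define KERNEL_CREATE_INFO(version, op) \
  //     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, version, op)>
  //   #define KERNEL_CREATE_INFO_VERSIONED(start, end, op) \
  //     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, start, end, op)>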
KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), + KERNEL_CREATE_INFO(19, Cast), + + // // activations + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + KERNEL_CREATE_INFO(6, Elu), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), + KERNEL_CREATE_INFO(14, Relu), + KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), + KERNEL_CREATE_INFO(16, LeakyRelu), + KERNEL_CREATE_INFO(10, ThresholdedRelu), + KERNEL_CREATE_INFO(20, Gelu), + + // // binary - math + KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), + KERNEL_CREATE_INFO(14, Add), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), + KERNEL_CREATE_INFO(14, Sub), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), + KERNEL_CREATE_INFO(14, Mul), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), + KERNEL_CREATE_INFO(14, Div), + KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), + KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), + KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), + KERNEL_CREATE_INFO(15, Pow), + KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal), + KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal), + KERNEL_CREATE_INFO(19, Equal), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), + KERNEL_CREATE_INFO(13, Greater), + KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), + KERNEL_CREATE_INFO(16, GreaterOrEqual), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), + KERNEL_CREATE_INFO(13, Less), + KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), + KERNEL_CREATE_INFO(16, LessOrEqual), + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + 
KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, }; for (auto& function_table_entry : 
function_table) { @@ -93,8 +741,77 @@ std::unique_ptr RegisterKernels() { using namespace webgpu; -WebGpuExecutionProvider::WebGpuExecutionProvider() - : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)} {} +WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, + WebGpuContext& context, + WebGpuExecutionProviderInfo&& info) + : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, + context_id_{context_id}, + context_{context}, + preferred_data_layout_{info.data_layout}, + force_cpu_node_names_{std::move(info.force_cpu_node_names)}, + enable_graph_capture_{info.enable_graph_capture} { +} + +std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) { + return std::make_unique(context_); + }, + 0, false); + return std::vector{CreateAllocator(gpuBufferAllocatorCreationInfo)}; +} + +std::vector> WebGpuExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + InlinedVector candidates; + // `tenative_candidates` is a subset of `candidates`. + InlinedVector tenative_candidates; + for (auto& node_index : graph.GetNodesInTopologicalOrder()) { + const auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; + + const auto& node = *p_node; + if (!node.GetExecutionProviderType().empty()) { + // If the node was added by layout transformer, do not move it to CPU + if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) { + candidates.push_back(node.Index()); + } + continue; + } + + const KernelCreateInfo* webgpu_kernel_def = kernel_lookup.LookUpKernel(node); + // none of the provided registries has a webgpu kernel for this node + if (webgpu_kernel_def == nullptr) { + LOGS(*GetLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.OpType() << " node name: " << node.Name(); + continue; + } + + // TODO: currently this lookup is O(N). If the list becomes large we should optimize this. 
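  // One possible shape for that optimization, as a sketch (assumes building the lookup set once,
  // above the node loop, is acceptable; this is not part of the change itself):
  //   std::unordered_set<std::string> force_cpu_lookup(force_cpu_node_names_.cbegin(),
  //                                                    force_cpu_node_names_.cend());
  //   ...
  //   if (force_cpu_lookup.count(node.Name()) > 0) {  // O(1) per node instead of O(N)
  //     continue;
  //   }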
+ if (std::find(force_cpu_node_names_.cbegin(), + force_cpu_node_names_.cend(), + node.Name()) != force_cpu_node_names_.cend()) { + LOGS(*GetLogger(), INFO) << "Force CPU execution for node: " << node.Name(); + continue; + } + candidates.push_back(node.Index()); + tenative_candidates.push_back(node.Index()); + } + + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); + std::vector> result; + for (auto& node_index : candidates) { + if (cpu_nodes.count(node_index) > 0) { + continue; + } + + auto sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node_index); + result.emplace_back(std::make_unique(std::move(sub_graph))); + } + return result; +} std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { static std::shared_ptr registry = webgpu::RegisterKernels(); @@ -102,7 +819,68 @@ std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() con return registry; } +std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { + return std::make_unique(context_); +} + WebGpuExecutionProvider::~WebGpuExecutionProvider() { } +std::unique_ptr WebGpuExecutionProvider::GetProfiler() { + auto profiler = std::make_unique(context_); + profiler_ = profiler.get(); + return profiler; +} + +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { + if (profiler_->Enabled()) { + context_.StartProfiling(); + } + + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + } + return Status::OK(); +} + +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + // is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + context_.Flush(); + + if (profiler_->Enabled()) { + context_.CollectProfilingData(profiler_->Events()); + } + + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool WebGpuExecutionProvider::IsGraphCaptured(int) const { + return is_graph_captured_; +} + +Status WebGpuExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); + ORT_ENFORCE(false); + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 537ecb9301f67..f9c43c6bfd7d0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -9,6 +9,7 @@ #include "core/graph/constants.h" #include "core/providers/providers.h" +struct pthreadpool; namespace onnxruntime { namespace webgpu { @@ -16,22 +17,80 @@ namespace webgpu { template KernelCreateInfo BuildKernelCreateInfo(); +class WebGpuContext; +enum class BufferCacheMode; +class WebGpuProfiler; } // namespace webgpu +struct WebGpuExecutionProviderInfo { + WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture) + : data_layout{data_layout}, + 
enable_graph_capture{enable_graph_capture}, + backend_type{}, + storage_buffer_cache_mode{}, + uniform_buffer_cache_mode{}, + query_resolve_buffer_cache_mode{}, + default_buffer_cache_mode{} {} + WebGpuExecutionProviderInfo(WebGpuExecutionProviderInfo&&) = default; + WebGpuExecutionProviderInfo& operator=(WebGpuExecutionProviderInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderInfo); + + DataLayout data_layout; + bool enable_graph_capture; + int backend_type; + webgpu::BufferCacheMode storage_buffer_cache_mode; + webgpu::BufferCacheMode uniform_buffer_cache_mode; + webgpu::BufferCacheMode query_resolve_buffer_cache_mode; + webgpu::BufferCacheMode default_buffer_cache_mode; + std::vector force_cpu_node_names; +}; + class WebGpuExecutionProvider : public IExecutionProvider { public: - WebGpuExecutionProvider(); + WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& info); ~WebGpuExecutionProvider() override; + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const override; + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; - DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } // WebGPU EP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to // work, and concurrent run with async function may mess up the states and cause undefined behavior. bool ConcurrentRunSupported() const override { return false; } + + std::vector CreatePreferredAllocators() override; + + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + + // WebGPU EP reuses the Device ID as the key to get the WebGpuContext instance. + int GetDeviceId() const override { return context_id_; } + + std::unique_ptr GetProfiler() override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); + int context_id_; + webgpu::WebGpuContext& context_; + webgpu::WebGpuProfiler* profiler_ = nullptr; + DataLayout preferred_data_layout_; + std::vector force_cpu_node_names_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h new file mode 100644 index 0000000000000..d7682e751d9e4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/compute_context.h" + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +// ----------------------------------------------------------------------- +// Base class for WebGPU kernels +// ----------------------------------------------------------------------- +class WebGpuKernel : public OpKernel { + public: + explicit WebGpuKernel(const OpKernelInfo& info) + : OpKernel(info) { + } + + Status Compute(OpKernelContext* p_op_kernel_context) const override { + ComputeContext context{*p_op_kernel_context}; + + context.PushErrorScope(); + Status s = ComputeInternal(context); + ORT_RETURN_IF_ERROR(context.PopErrorScope()); + + return s; + } + + virtual Status ComputeInternal(ComputeContext& context) const = 0; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_profiler.cc b/onnxruntime/core/providers/webgpu/webgpu_profiler.cc new file mode 100644 index 0000000000000..ce973987e593a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_profiler.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_profiler.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +WebGpuProfiler::WebGpuProfiler(WebGpuContext& context) : context_{context} {} + +bool WebGpuProfiler::StartProfiling(TimePoint) { + enabled_ = true; + return true; +} + +void WebGpuProfiler::EndProfiling(TimePoint tp, onnxruntime::profiling::Events& events) { + context_.EndProfiling(tp, events, events_); + enabled_ = false; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_profiler.h b/onnxruntime/core/providers/webgpu/webgpu_profiler.h new file mode 100644 index 0000000000000..d826d295a3842 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_profiler.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
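For reference, a concrete operator kernel builds on the WebGpuKernel base class added in webgpu_kernel.h above by overriding ComputeInternal, which Compute() wraps between PushErrorScope and PopErrorScope. A minimal sketch (the kernel name and empty body are illustrative only; real kernels dispatch a WebGPU program through the ComputeContext):

namespace onnxruntime {
namespace webgpu {

// Hypothetical kernel, shown only to illustrate the base-class contract.
class ExampleKernel final : public WebGpuKernel {
 public:
  explicit ExampleKernel(const OpKernelInfo& info) : WebGpuKernel(info) {}

  Status ComputeInternal(ComputeContext& context) const override {
    // A real kernel would read inputs, allocate outputs, and run a shader via `context`;
    // those APIs are outside this hunk, so the body is left empty.
    ORT_UNUSED_PARAMETER(context);
    return Status::OK();
  }
};

}  // namespace webgpu
}  // namespace onnxruntime
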
+ +#pragma once + +#include "core/common/profiler_common.h" + +namespace onnxruntime { + +namespace webgpu { +class WebGpuContext; + +class WebGpuProfiler final : public onnxruntime::profiling::EpProfiler { + public: + WebGpuProfiler(WebGpuContext& context); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuProfiler); + ~WebGpuProfiler() {} + bool StartProfiling(TimePoint) override; + void EndProfiling(TimePoint, onnxruntime::profiling::Events&) override; + void Start(uint64_t) override { + } + void Stop(uint64_t) override { + } + inline bool Enabled() const { return enabled_; } + inline onnxruntime::profiling::Events& Events() { return events_; } + + private: + WebGpuContext& context_; + bool enabled_{false}; + onnxruntime::profiling::Events events_; // cached GPU events +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 1a1f1a438c750..6cfe9aac0b0e9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -4,21 +4,214 @@ #include #include "core/framework/error_code_helper.h" -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/ort_apis.h" + +#include "core/providers/webgpu/webgpu_provider_options.h" +using namespace onnxruntime::webgpu::options; namespace onnxruntime { struct WebGpuProviderFactory : IExecutionProviderFactory { - WebGpuProviderFactory() {} + WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& webgpu_ep_info) + : context_id_{context_id}, context_{context}, info_{std::move(webgpu_ep_info)} { + } std::unique_ptr CreateProvider() override { - return std::make_unique(); + return std::make_unique(context_id_, context_, std::move(info_)); } + + private: + int context_id_; + webgpu::WebGpuContext& context_; + WebGpuExecutionProviderInfo info_; }; -std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions&) { - return std::make_shared(); +std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { + // + // STEP.1 - prepare WebGpuExecutionProviderInfo + // + WebGpuExecutionProviderInfo webgpu_ep_info{ + // preferred layout is NHWC by default + DataLayout::NHWC, + // graph capture feature is disabled by default + false, + }; + + std::string preferred_layout_str; + if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (preferred_layout_str == kPreferredLayout_NHWC) { + webgpu_ep_info.data_layout = DataLayout::NHWC; + } else if (preferred_layout_str == kPreferredLayout_NCHW) { + webgpu_ep_info.data_layout = DataLayout::NCHW; + } else { + ORT_THROW("Invalid preferred layout: ", preferred_layout_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_info.data_layout) << " (parsed from \"" + << preferred_layout_str << "\")"; + + std::string enable_graph_capture_str; + if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (enable_graph_capture_str == kEnableGraphCapture_ON) { + webgpu_ep_info.enable_graph_capture = true; + } else if 
(enable_graph_capture_str == kEnableGraphCapture_OFF) { + webgpu_ep_info.enable_graph_capture = false; + } else { + ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; + + std::string backend_type_str; + if (config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { +#ifdef _WIN32 + // Setup Windows default backend type based on the build configuration +#if defined(onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_D3D12); +#elif defined(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_Vulkan); +#endif +#endif + if (backend_type_str == kDawnBackendType_D3D12) { + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_D3D12); + } else if (backend_type_str == kDawnBackendType_Vulkan) { + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_Vulkan); + } else { + ORT_THROW("Invalid Dawn backend type: ", backend_type_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << webgpu_ep_info.backend_type; + + auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, + webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { + std::string buffer_cache_mode_str; + if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { + return webgpu::BufferCacheMode::Disabled; + } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { + return webgpu::BufferCacheMode::LazyRelease; + } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { + return webgpu::BufferCacheMode::Simple; + } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { + return webgpu::BufferCacheMode::Bucket; + } else { + ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + } + } else { + return default_value; + } + }; + + webgpu_ep_info.storage_buffer_cache_mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << webgpu_ep_info.storage_buffer_cache_mode; + + webgpu_ep_info.uniform_buffer_cache_mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::Simple); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << webgpu_ep_info.uniform_buffer_cache_mode; + + webgpu_ep_info.query_resolve_buffer_cache_mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << webgpu_ep_info.query_resolve_buffer_cache_mode; + + webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; + + webgpu::ValidationMode validation_mode = +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::Basic // for release build, enable basic validation by default +#endif // !NDEBUG + ; + std::string validation_mode_str; + if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (validation_mode_str == kValidationMode_Disabled) { + validation_mode = webgpu::ValidationMode::Disabled; + } else if (validation_mode_str == 
kValidationMode_wgpuOnly) { + validation_mode = webgpu::ValidationMode::WGPUOnly; + } else if (validation_mode_str == kValidationMode_basic) { + validation_mode = webgpu::ValidationMode::Basic; + } else if (validation_mode_str == kValidationMode_full) { + validation_mode = webgpu::ValidationMode::Full; + } else { + ORT_THROW("Invalid validation mode: ", validation_mode_str); + } + } + + // parse force CPU node names + // The force CPU node names are separated by EOL (\n or \r\n) in the config entry. + // each line is a node name that will be forced to run on CPU. + std::string force_cpu_node_names_str; + if (config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { + std::vector force_cpu_node_names; + + // split the string by EOL (\n or \r\n) + std::istringstream ss(force_cpu_node_names_str); + std::string line; + while (std::getline(ss, line)) { + // skip empty lines + if (line.empty()) { + continue; + } + + force_cpu_node_names.push_back(line); + } + + webgpu_ep_info.force_cpu_node_names = std::move(force_cpu_node_names); + } + + // + // STEP.2 - prepare WebGpuContext + // + int context_id = 0; + std::string context_id_str; + if (config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); + } + + size_t webgpu_instance = 0; + std::string webgpu_instance_str; + if (config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); + } + + size_t webgpu_adapter = 0; + std::string webgpu_adapter_str; + if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { + static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec); + } + + size_t webgpu_device = 0; + std::string webgpu_device_str; + if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); + } + + size_t dawn_proc_table = 0; + std::string dawn_proc_table_str; + if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); + } + + auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, + reinterpret_cast(webgpu_instance), + reinterpret_cast(webgpu_adapter), + reinterpret_cast(webgpu_device), + validation_mode); + context.Initialize(webgpu_ep_info, reinterpret_cast(dawn_proc_table)); + + return std::make_shared(context_id, context, std::move(webgpu_ep_info)); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 6257a85d45760..e0030a3ec2a11 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ 
-8,6 +8,8 @@ #include "core/framework/provider_options.h" #include "core/providers/providers.h" +#include "core/providers/webgpu/webgpu_provider_options.h" + namespace onnxruntime { struct ConfigOptions; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h new file mode 100644 index 0000000000000..12bb4b32e6a35 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webgpu { +namespace options { + +// The following are the options that can be set in the WebGPU provider options. + +constexpr const char* kPreferredLayout = "WebGPU:preferredLayout"; +constexpr const char* kEnableGraphCapture = "WebGPU:enableGraphCapture"; + +constexpr const char* kDawnProcTable = "WebGPU:dawnProcTable"; + +constexpr const char* kDawnBackendType = "WebGPU:dawnBackendType"; + +constexpr const char* kDeviceId = "WebGPU:deviceId"; +constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance"; +constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter"; +constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice"; + +constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode"; +constexpr const char* kUniformBufferCacheMode = "WebGPU:uniformBufferCacheMode"; +constexpr const char* kQueryResolveBufferCacheMode = "WebGPU:queryResolveBufferCacheMode"; +constexpr const char* kDefaultBufferCacheMode = "WebGPU:defaultBufferCacheMode"; + +constexpr const char* kValidationMode = "WebGPU:validationMode"; + +constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames"; + +// The following are the possible values for the provider options. + +constexpr const char* kDawnBackendType_D3D12 = "D3D12"; +constexpr const char* kDawnBackendType_Vulkan = "Vulkan"; + +constexpr const char* kPreferredLayout_NCHW = "NCHW"; +constexpr const char* kPreferredLayout_NHWC = "NHWC"; + +constexpr const char* kEnableGraphCapture_ON = "1"; +constexpr const char* kEnableGraphCapture_OFF = "0"; + +constexpr const char* kBufferCacheMode_Disabled = "disabled"; +constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; +constexpr const char* kBufferCacheMode_Simple = "simple"; +constexpr const char* kBufferCacheMode_Bucket = "bucket"; + +constexpr const char* kValidationMode_Disabled = "disabled"; +constexpr const char* kValidationMode_wgpuOnly = "wgpuOnly"; +constexpr const char* kValidationMode_basic = "basic"; +constexpr const char* kValidationMode_full = "full"; + +} // namespace options +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h new file mode 100644 index 0000000000000..ff66cd535399e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
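On the application side, the option keys defined in webgpu_provider_options.h above follow the session-config-entry naming pattern, so they would typically be supplied as string key/value pairs before session creation. A minimal sketch, assuming the generic Ort::SessionOptions::AddConfigEntry mechanism is how a given build surfaces these keys (the node names are placeholders):

#include <onnxruntime_cxx_api.h>

Ort::SessionOptions MakeWebGpuSessionOptions() {
  Ort::SessionOptions so;
  so.AddConfigEntry("WebGPU:preferredLayout", "NHWC");              // kPreferredLayout_NHWC
  so.AddConfigEntry("WebGPU:enableGraphCapture", "0");              // kEnableGraphCapture_OFF
  so.AddConfigEntry("WebGPU:storageBufferCacheMode", "bucket");     // kBufferCacheMode_Bucket
  so.AddConfigEntry("WebGPU:validationMode", "basic");              // kValidationMode_basic
  so.AddConfigEntry("WebGPU:forceCpuNodeNames", "node_a\nnode_b");  // newline-separated node names
  return so;
}
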
+ +#pragma once + +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +using SupportedNumberTypes = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t>; + +using SupportedFloats = + TypeList< + float, + MLFloat16>; + +inline const std::vector& WebGpuSupportedNumberTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +inline const std::vector& WebGpuSupportedFloatTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index b90c7d76a6507..45a87960126cd 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,17 +69,17 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) { - const auto& input_name = input.Name(); - const auto* shape_proto = input.Shape(); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input) { + const auto& node_arg_name = node_arg.Name(); + const auto* shape_proto = node_arg.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. - if (input_name.empty()) { + if (node_arg_name.empty()) { return true; } - // We do not support input with no shape. + // We do not support input/output with no shape. if (!shape_proto) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name - << "] has not shape"; + LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape"; return false; } @@ -87,8 +87,11 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons // WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape. if (!dim.has_dim_value()) { LOGS(logger, VERBOSE) << "Dynamic shape is not supported, " - << "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: " - << input_name; + << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name; + return false; + } + if (dim.dim_value() == 0 && !allow_empty_input) { + LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN"; return false; } } @@ -102,13 +105,6 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; - - for (const auto* input : graph_viewer.GetInputs()) { - if (!IsInputSupported(*input, "graph", logger)) { - return supported_node_groups; - } - } - std::vector supported_node_group; const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); @@ -118,7 +114,6 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v bool supported = false; // Firstly check if platform supports the WebNN op. 
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { - LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); } @@ -183,14 +178,31 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) return false; + return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits, + webnn_input_output_name, onnx_input_output_name, logger); +} + +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + if (wnn_limits[webnn_op_type].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now"; + return false; + } + if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter [" + << webnn_input_output_name << "]"; + return false; + } if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { - LOGS(logger, VERBOSE) << "[" << onnx_op_type - << "] " << onnx_input_output_name - << " type: [" << onnx_data_type - << "] is not supported for now"; + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: [" + << onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now"; return false; } - return true; } @@ -226,6 +238,12 @@ bool GetBidirectionalBroadcastShape(std::vector& shape_a, bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + desc.set("dataType", emscripten::val("int4")); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + desc.set("dataType", emscripten::val("uint4")); + return true; case ONNX_NAMESPACE::TensorProto_DataType_BOOL: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: desc.set("dataType", emscripten::val("uint8")); @@ -261,5 +279,67 @@ bool IsMLTensorSupported() { return is_supported; } +// Convert int8 to uint4/int4 (stored as uint8) +uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type) { + uint8_t result = 0; + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + if (value < 0 || value > 15) { + ORT_THROW("Value cannot be safely converted to uint4."); + } + result |= (static_cast(value) << 4); + } else { + if (value < -8 || value > 7) { + ORT_THROW("Value cannot be safely converted to int4."); + } + result |= (value << 4); + } + + return result; +} + +// Convert float32 to float16 (stored as uint16) +uint16_t PackFloat32ToUint16AsFloat16(float value) { + uint32_t float32_bits; + + // Safely copy the float bits into an integer + std::memcpy(&float32_bits, &value, sizeof(float)); + + // Extract the sign, exponent, and mantissa from the float32 bits + uint32_t sign = (float32_bits >> 31) & 0x1; + uint32_t exponent = (float32_bits >> 23) & 0xFF; + uint32_t mantissa = float32_bits & 0x7FFFFF; + + // Shift the sign for float16 + uint16_t sign_float16 = sign << 15; + + // Handle special cases: Infinity and NaN + if (exponent == 255) { + return sign_float16 | (0x1F << 10) | (mantissa ? 
0x200 : 0); + } + // Handle zero and subnormal numbers in float32 + if (exponent == 0) { + return sign_float16 | (mantissa >> 13); + } + + // Adjust the exponent for float16 (subtract bias difference: 127 - 15 = 112) + int exponent_float16 = exponent - 112; + + // Handle exponent overflow (larger than float16 can represent) + if (exponent_float16 >= 0x1F) { + return sign_float16 | (0x1F << 10); + } + // Handle exponent underflow (smaller than float16 can represent) + if (exponent_float16 <= 0) { + mantissa = (mantissa | 0x800000) >> (1 - exponent_float16); + return sign_float16 | (mantissa >> 13); + } + + // Adjust the mantissa by shifting it to fit float16 format (round to nearest even) + uint16_t mantissa_float16 = (mantissa + 0x1000) >> 13; + + // Combine sign, exponent, and mantissa into the final float16 representation + return sign_float16 | (exponent_float16 << 10) | mantissa_float16; +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index aecb1f7a03bb9..a06f46f1bdf0a 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include "core/common/inlined_containers.h" #include @@ -36,6 +37,31 @@ WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type); // Collects all the initializer tensors in the subGraph and its ancestor graphs. InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer); +inline std::vector convertAxesFromNCHWtoNHWC(const std::vector& axes) { + constexpr std::array nchw_to_nhwc = {0, 3, 1, 2}; + std::vector new_axes; + new_axes.reserve(axes.size()); + for (int64_t axis : axes) { + if (axis >= nchw_to_nhwc.size()) { + ORT_THROW("Invalid axis value: ", axis); + } + new_axes.push_back(nchw_to_nhwc[static_cast(axis)]); + } + return new_axes; +} + +inline std::vector HandleNegativeAxes(const std::vector& axes, size_t input_size) { + std::vector new_axes(axes.size()); + for (size_t i = 0; i < axes.size(); ++i) { + new_axes[i] = HandleNegativeAxis(axes[i], input_size); + } + return new_axes; +} + +inline std::vector GetResolvedAxes(const NodeAttrHelper& helper, size_t input_size) { + return HandleNegativeAxes(helper.Get("axes", std::vector{}), input_size); +} + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); template @@ -56,7 +82,7 @@ inline std::string GetTensorName(const ConstPointerContainer index) ? std::string(input_defs[index]->Name()) : ""; } -inline std::vector GetVecUint32FromVecInt64(const std::vector& int64_vec) { +inline std::vector GetVecUint32FromVecInt64(gsl::span int64_vec) { std::vector uint32_vec; uint32_vec.reserve(int64_vec.size()); std::transform(int64_vec.begin(), int64_vec.end(), @@ -144,7 +170,19 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va return true; } -bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::string& name) { + if (name.empty() || !Contains(initializers, name)) { + return true; + } + + const auto& tensor = *initializers.at(name); + const auto dims = tensor.dims(); + // An empty tensor contains a 0 in the dimensions list. 
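+  // e.g. an initializer with dims [2, 0, 3] holds zero elements and is therefore treated as empty.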
+ return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); +} + +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input = false); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, @@ -155,6 +193,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v static const InlinedHashMap op_map = { {"Abs", "abs"}, {"Add", "add"}, + {"And", "logicalAnd"}, {"ArgMax", "argMax"}, {"ArgMin", "argMin"}, {"AveragePool", "averagePool2d"}, @@ -167,10 +206,12 @@ static const InlinedHashMap op_map = { {"ConvInteger", "conv2dInteger"}, {"ConvTranspose", "convTranspose2d"}, {"Cos", "cos"}, + {"CumSum", "cumulativeSum"}, {"Div", "div"}, {"DequantizeLinear", "dequantizeLinear"}, {"Dropout", "identity"}, {"DynamicQuantizeLinear", "dynamicQuantizeLinear"}, + {"Einsum", "matmul"}, {"Elu", "elu"}, {"Equal", "equal"}, {"Erf", "erf"}, @@ -179,6 +220,8 @@ static const InlinedHashMap op_map = { {"Flatten", "reshape"}, {"Floor", "floor"}, {"Gather", "gather"}, + {"GatherElements", "gatherElements"}, + {"GatherND", "gatherND"}, {"Gelu", "gelu"}, {"Gemm", "gemm"}, {"GlobalAveragePool", "averagePool2d"}, @@ -186,7 +229,7 @@ static const InlinedHashMap op_map = { {"GlobalLpPool", "l2Pool2d"}, {"Greater", "greater"}, {"GreaterOrEqual", "greaterOrEqual"}, - {"Gru", "gru"}, + {"GRU", "gru"}, {"HardSigmoid", "hardSigmoid"}, {"HardSwish", "hardSwish"}, {"Identity", "identity"}, @@ -198,6 +241,7 @@ static const InlinedHashMap op_map = { {"Log", "log"}, {"LpPool", "l2Pool2d"}, {"LSTM", "lstm"}, + {"LRN", "averagePool2d"}, {"MatMul", "matmul"}, {"MatMulInteger", "matmulInteger"}, {"Max", "max"}, @@ -206,6 +250,7 @@ static const InlinedHashMap op_map = { {"Mul", "mul"}, {"Neg", "neg"}, {"Not", "logicalNot"}, + {"Or", "logicalOr"}, {"Pad", "pad"}, {"Pow", "pow"}, {"PRelu", "prelu"}, @@ -224,8 +269,12 @@ static const InlinedHashMap op_map = { {"Relu", "relu"}, {"Reshape", "reshape"}, {"Resize", "resample2d"}, + {"ScatterElements", "scatterElements"}, + {"ScatterND", "scatterND"}, {"Shape", "slice"}, {"Sigmoid", "sigmoid"}, + {"Sign", "sign"}, + {"SimplifiedLayerNormalization", "layerNormalization"}, {"Softplus", "softplus"}, {"Softsign", "softsign"}, {"Sin", "sin"}, @@ -242,6 +291,7 @@ static const InlinedHashMap op_map = { {"Trilu", "triangular"}, {"Unsqueeze", "reshape"}, {"Where", "where"}, + {"Xor", "logicalXor"}, }; inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder, @@ -267,6 +317,8 @@ inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_typ } static const InlinedHashMap onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"}, {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, @@ -288,6 +340,13 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, const std::string& webnn_input_output_name, const std::string& onnx_input_output_name, const logging::Logger& logger); +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const 
logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, @@ -297,5 +356,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); bool IsMLTensorSupported(); +uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type); +uint16_t PackFloat32ToUint16AsFloat16(float value); + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 8da255a288f17..290d16a48dd83 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -12,27 +12,6 @@ namespace onnxruntime { namespace webnn { - -// Shared functions. -bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, - const logging::Logger& logger) { - for (const auto* node_arg : node.InputDefs()) { - const auto& input_name(node_arg->Name()); - if (!Contains(initializers, input_name)) - continue; - - const auto& tensor = *initializers.at(input_name); - if (tensor.has_data_location() && - tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - LOGS(logger, VERBOSE) << "Initializer [" << input_name - << "] with external data location are not currently supported"; - return true; - } - } - - return false; -} - // Add operator related. Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, @@ -42,8 +21,6 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& model_builder.GetOpSupportLimits(), logger), "Unsupported operator ", node.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); - LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() - << "] type: [" << node.OpType() << "] was added"; return Status::OK(); } @@ -52,14 +29,10 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const emscripten::val& wnn_limits, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, wnn_limits, logger)) - return false; - - if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) + if (!HasSupportedInputs(initializers, node, wnn_limits, logger)) return false; - // We do not support external initializers for now. 
- if (HasExternalInitializer(initializers, node, logger)) + if (!HasSupportedOutputs(node, wnn_limits, logger)) return false; if (!HasSupportedOpSet(node, logger)) @@ -68,19 +41,19 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, +bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsInputSupported(*input, node_name, logger)) { + if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) { return false; } } - return HasSupportedInputsImpl(node, wnn_limits, logger); + return HasSupportedInputsImpl(initializers, node, wnn_limits, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, +bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. @@ -93,6 +66,18 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); } +bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); + for (const auto* output : node.OutputDefs()) { + if (!IsTensorShapeSupported(*output, node_name, logger)) { + return false; + } + } + + return HasSupportedOutputsImpl(node, wnn_limits, logger); +} + bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 584455f62cb4e..0a4367a71add4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; protected: + explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false) + : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { + } virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; @@ -37,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; @@ -53,7 +56,10 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const 
logging::Logger& logger) const; + bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + + const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index af82a01b14de5..e14507e8f5aea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,8 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -86,8 +86,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index 3c4fc822f3d01..4b2f04bed0eb1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -38,6 +38,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT); std::string operand_type; switch (to_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + operand_type = "int4"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + operand_type = "uint4"; + break; case ONNX_NAMESPACE::TensorProto_DataType_BOOL: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: operand_type = "uint8"; @@ -80,8 +86,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
-bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 48dd6f3beb020..bac528300e077 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -21,8 +21,8 @@ class ConcatOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -42,7 +42,6 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector inputs; for (const auto* input : input_defs) { - LOGS(logger, VERBOSE) << "input name " << input->Name(); inputs.push_back(model_builder.GetOperand(input->Name())); } @@ -56,8 +55,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index f03e5b90ff6db..81e688ea4f8ea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,8 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -133,7 +133,7 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, const auto out_t = dims[0], in_t = dims[1], h_t = dims[2], w_t = dims[3]; std::vector dest_shape; - if (is_conv == 1) + if (is_conv) dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 else dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight @@ -265,7 +265,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N options.set("inputLayout", 
emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); if (is_constant_weight) { - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); } } } @@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + x_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } if (input_defs.size() >= 4) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + w_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } output = model_builder.GetBuilder().call("conv2dInteger", input, x_zero_point, filter, w_zero_point, options); @@ -397,8 +397,8 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc new file mode 100644 index 0000000000000..be30c5520d62e --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class CumSumOpBuilder : public BaseOpBuilder { + // Add operator related. + + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +// Add operator related. + +void CumSumOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip axis. 
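+  // The axis value is read at build time and folded into the WebNN cumulativeSum() call,
+  // so the initializer must not be turned into a separate WebNN constant or graph input.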
+ model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); +} + +Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); + const auto input_rank = input_shape.size(); + + const auto& initializers = model_builder.GetInitializerTensors(); + const std::string axis_name = GetTensorName(input_defs, 1); + const auto axis_tensor = *initializers.at(axis_name); + emscripten::val axis = emscripten::val::undefined(); + ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value"); + int64_t webnn_axis = HandleNegativeAxis(axis.as(), input_rank); + + NodeAttrHelper helper(node); + const auto exclusive = helper.Get("exclusive", 0); + const auto reverse = helper.Get("reverse", 0); + + emscripten::val options = emscripten::val::object(); + options.set("exclusive", exclusive == 1); + options.set("reversed", reverse == 1); + options.set("label", node.Name()); + + emscripten::val output = emscripten::val::object(); + output = model_builder.GetBuilder().call("cumulativeSum", input, gsl::narrow(webnn_axis), + options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. +bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const std::string axis_name = GetTensorName(input_defs, 1); + // Inputs contain optional 'axis' input. 
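+  // WebNN's cumulativeSum() takes the axis as a compile-time integer rather than as a tensor,
+  // so the ONNX 'axis' input can only be handled when it is provided as a constant initializer.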
+ if (!Contains(initializers, axis_name)) { + LOGS(logger, VERBOSE) << "The axis must be a constant initializer."; + return false; + } + + return true; +} + +void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index 5434194a214ac..9bb930c63b009 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -59,22 +59,14 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector mask_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape"); std::vector dims = GetVecUint32FromVecInt64(mask_shape); - - emscripten::val desc = emscripten::val::object(); - desc.set("dataType", "uint8"); - desc.set("dimensions", emscripten::val::array(dims)); - desc.set("shape", emscripten::val::array(dims)); - const auto num_elements = narrow(Product(mask_shape)); - emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements); - ones_buffer.call("fill", 1); - - emscripten::val mask_output = model_builder.GetBuilder().call("constant", desc, ones_buffer); + emscripten::val one_constant = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_BOOL, 1, dims); emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, // beacuse WebNN does not support a constant operand as output. - mask_output = model_builder.GetBuilder().call("identity", mask_output, options); + emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc new file mode 100644 index 0000000000000..ef713f48b8135 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -0,0 +1,793 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/cpu/tensor/reshape_helper.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class EinsumOpBuilder : public BaseOpBuilder { + // Add operator related. + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. 
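+  // Support is decided by parsing the equation into per-input/output label components and
+  // classifying it into one of the RecognizedOperatorType patterns defined below; equations
+  // that do not match any pattern (including ones that use an ellipsis) are rejected.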
+ bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Helper functions, thanks for DML EP's OperatorHelper. +enum class RecognizedOperatorType { + None, + Identity, + ReduceSum, + Transpose, + Diagonal, + Multiply, + Pairwise, + Total, +}; + +struct RecognizedOperatorInfo { + RecognizedOperatorType recognized_operator_type; + std::initializer_list component_ranks; + std::initializer_list label_indices; +}; + +struct Component { + uint32_t label_index_begin; + uint32_t label_index_end; + + uint32_t GetDimensionCount() const noexcept { + return label_index_end - label_index_begin; + } + gsl::span GetLabels(gsl::span labels) const { + return labels.subspan(label_index_begin, label_index_end - label_index_begin); + } +}; + +bool ParseEquationComponents(const Node& node, + const std::string_view equation, + std::vector& label_indices, + std::vector& components, + std::vector& output_dimensions, + uint32_t& num_labels, + const logging::Logger& logger) { + // Parse an equation like 'ij,jk->ik' into components {ij, jk, ik} mapping letters to + // numeric indices {(0,1}, {1,2}, {0,2}}. The last component is the output. + // Read first to last character in equation, looking for letters, commas, and one arrow. + // The ellipsis is not supported. + std::map label_maps; + std::set repeated_labels; + + num_labels = 0; + Component current_component = {}; + bool at_output = false; + bool end_flag = false; + + for (const char* it = equation.data(); !end_flag; ++it) { + // std::string.data() promises the end of the string is '\0' + char ch = *it; + + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) { + const auto [i, inserted] = label_maps.insert({ch, num_labels}); + if (inserted) { + if (at_output) { + LOGS(logger, VERBOSE) << "Found label in equation output not matching any label from inputs."; + return false; + } + ++num_labels; + } else if (!at_output) { + repeated_labels.insert(ch); + } + label_indices.push_back(i->second); + } else if (ch == ' ') { + continue; + } else { + current_component.label_index_end = static_cast(label_indices.size()); + components.push_back(current_component); + current_component.label_index_begin = current_component.label_index_end; + + switch (ch) { + case ',': + break; + + case '-': + ++it; + if (*it != '>') { + LOGS(logger, VERBOSE) << "Expected '->' for output."; + return false; + } + if (at_output) { + LOGS(logger, VERBOSE) << "Only one output arrow '->' is valid."; + return false; + } + at_output = true; + break; + + case '.': + // Ellipsis is unsupported + LOGS(logger, VERBOSE) << "Ellipsis is unsupported."; + return false; + + case '\0': + end_flag = true; + break; // End of string. + + default: + LOGS(logger, VERBOSE) << "Unsupported character in equation string."; + return false; + } + } + } + + // If no explicit output was given, generate an implicit output by ordering all the + // labels in alphabetic order (by ASCII value consistent with numpy, so Z < a). + // Exclude any labels that occurred more than once, as these cancel out. 
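+  // For example, for the implicit-form equation "ij,jk" the label 'j' appears in both inputs,
+  // so it is excluded from the generated output and the output component becomes "ik".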
+ if (!at_output) { + for (auto i : label_maps) { + if (repeated_labels.count(i.first) == 0) { + label_indices.push_back(i.second); + } + } + + current_component.label_index_end = static_cast(label_indices.size()); + components.push_back(current_component); + } + return true; +} + +// For two inputs A,B and one output C +Status PairwiseOperandProcess(ModelBuilder& model_builder, + const Node& node, + const std::vector& label_indices, + const std::vector& components, + const std::vector& output_dimensions, + uint32_t num_labels, + emscripten::val& output, + const logging::Logger& logger) { + auto input_a_labels = components[0].GetLabels(label_indices); + auto input_b_labels = components[1].GetLabels(label_indices); + auto output_labels = components[2].GetLabels(label_indices); + + /* + Step 1. Transpose and Reshape + + (0/1,0/1,0/1) means dim i whether appears in (A,B,C) + For new A, it has three segments [...a_1..., a_2, a_3], a_1 has multiple dims, a_2 and a_3 only have one dim respectively + For new B, it has three segments [...b_1..., b_2, b_3], b_1 has multiple dims, b_2 and b_3 only have one dim respectively + a_1 and b_1 are batch dims, and [a_2,a_3], [b_2,b_3] are for matmul + + case (1,0,0) and (0,1,0): reduce, here we treat it as batch dimension, and reduceSum at the end. + add additional dim for B/A + case (1,1,1): batch dimension, put it in the front. + case (1,0,1): gemm dim for A, put it in a_2 + case (0,1,1): gemm dim for B, put it in b_3 + case (1,1,0): summation dim / gemm dim for both A and B, put it in a_3 and b_2 + + Notes: + # of (1,1,0) maybe > 1, flatten / reshape a_3 and b_2 + # of (1,1,0) maybe = 0, add one additional dim for a_3 and b_2 + */ + + // The index in input/output of the dim index + std::map input_a_axes_map, input_b_axes_map, output_axes_map; + + for (uint32_t i = 0; i <= num_labels + 1; ++i) { + input_a_axes_map[i] = input_b_axes_map[i] = output_axes_map[i] = -1; + } + int32_t index = 0; + for (auto axis : input_a_labels) { + input_a_axes_map[axis] = index++; + } + index = 0; + for (auto axis : input_b_labels) { + input_b_axes_map[axis] = index++; + } + index = 0; + for (auto axis : output_labels) { + output_axes_map[axis] = index++; + } + + // Inputs Reshape + // a_0 = [a_1,a_2,a_3], b_0 = [b_1,b_2,b_3] + std::vector a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3; + uint32_t a_idx = input_a_labels.size(); + uint32_t b_idx = input_b_labels.size(); + bool a_flag = false; // whether a_2 has element + bool b_flag = false; // whether b_3 has element + + for (uint32_t i = 0; i < num_labels; ++i) { + if (input_a_axes_map[i] != -1) { + if (input_b_axes_map[i] != -1) { + if (output_axes_map[i] != -1) { + // The index in input/output of the dim index + a_1.push_back(i); + b_1.push_back(i); + } else { + // (1,1,0) push back in the middle for b and end for a + a_3.push_back(i); + b_2.push_back(i); + } + } else { + // (1,0,x) push back in the middle for a. If more than one, push back in the front for a, b. + if (a_flag) { + a_1.push_back(i); + b_1.push_back(i); + input_b_axes_map[i] = b_idx++; + } else { + a_2.push_back(i); + a_flag = true; + } + } + } else { + // (0,1,x) push back in the end for b. If more than one, push back in the front for a, b. 
+ if (input_b_axes_map[i] != -1) { + if (b_flag) { + a_1.push_back(i); + b_1.push_back(i); + input_a_axes_map[i] = a_idx++; + } else { + b_3.push_back(i); + b_flag = true; + } + } + } + } + + // Matrix multiplication can be formatted in (...,i,j) * (...,j,k) ==> (...,i,k) + // Even inner and outer product can be reformatted as this. + // Inner product (1,i) * (i,1) ==> (1,1) + // Outer product (i,1) * (1,j) ==> (i,j) + // i.e., in our expression, (a_2,a_3) * (b_2,b_3) ==> (a_2,b_3) + + if (!a_flag) { + // Lack of a_2 element, add a new a_2, whose dim value = 1 + a_2.push_back(num_labels + 1); + input_a_axes_map[num_labels + 1] = a_idx++; + } + if (!b_flag) { + // Lack of b_3 element, add a new b_3, whose dim value = 1 + b_3.push_back(num_labels + 2); + input_b_axes_map[num_labels + 2] = b_idx++; + b_idx++; + } + + if (a_3.empty()) { + // Lack of a_3 and b_2 elements, add a new a_3 for A and a new b_2 for B, whose dim value = 1 + a_3.push_back(num_labels); + b_2.push_back(num_labels); + input_a_axes_map[num_labels] = a_idx; + input_b_axes_map[num_labels] = b_idx; + } + + a_0 = a_1; + b_0 = b_1; + a_0.insert(a_0.end(), a_2.begin(), a_2.end()); + a_0.insert(a_0.end(), a_3.begin(), a_3.end()); + b_0.insert(b_0.end(), b_2.begin(), b_2.end()); + b_0.insert(b_0.end(), b_3.begin(), b_3.end()); + + std::vector permutation_a, permutation_b; + for (uint32_t i = 0; i < a_0.size(); ++i) { + permutation_a.push_back(static_cast(input_a_axes_map[a_0[i]])); + permutation_b.push_back(static_cast(input_b_axes_map[b_0[i]])); + } + + const auto& input_defs = node.InputDefs(); + emscripten::val input_a = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val input_b = model_builder.GetOperand(input_defs[1]->Name()); + std::vector new_a_shape, new_b_shape; + if (input_a_labels.size() < a_0.size()) { + std::vector input_a_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_a_shape, logger), "Cannot get shape"); + std::transform(input_a_shape.begin(), input_a_shape.end(), std::back_inserter(new_a_shape), + [](int64_t i) { return static_cast(i); }); + for (uint32_t i = 0; i < a_0.size() - input_a_labels.size(); ++i) { + new_a_shape.push_back(SafeInt(1)); + } + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_reshape"); + input_a = model_builder.GetBuilder().call("reshape", + input_a, + emscripten::val::array(new_a_shape), + options); + } + if (input_b_labels.size() < b_0.size()) { + std::vector input_b_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], input_b_shape, logger), "Cannot get shape"); + std::transform(input_b_shape.begin(), input_b_shape.end(), std::back_inserter(new_b_shape), + [](int64_t i) { return static_cast(i); }); + for (uint32_t i = 0; i < b_0.size() - input_b_labels.size(); ++i) { + new_b_shape.push_back(SafeInt(1)); + } + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_reshape"); + input_b = model_builder.GetBuilder().call("reshape", + input_b, + emscripten::val::array(new_b_shape), + options); + } + + // Inputs Transpose + std::vector sequence(permutation_a.size()); + std::iota(sequence.begin(), sequence.end(), 0); + if (permutation_a != sequence) { + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation_a)); + options.set("label", node.Name() + "_transpose"); + input_a = model_builder.GetBuilder().call("transpose", input_a, options); + } + if (permutation_b != sequence) { + emscripten::val options = 
emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation_b)); + options.set("label", node.Name() + "_transpose"); + input_b = model_builder.GetBuilder().call("transpose", input_b, options); + } + + // Input Reshape: if the number of (1,1,0) > 1, flatten the b_2 and a_3 dims. + if (a_3.size() > 1) { + if (new_a_shape.empty()) { + std::vector input_a_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_a_shape, logger), "Cannot get shape"); + std::transform(input_a_shape.begin(), input_a_shape.end(), std::back_inserter(new_a_shape), + [](int64_t i) { return static_cast(i); }); + } + if (new_b_shape.empty()) { + std::vector input_b_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], input_b_shape, logger), "Cannot get shape"); + std::transform(input_b_shape.begin(), input_b_shape.end(), std::back_inserter(new_b_shape), + [](int64_t i) { return static_cast(i); }); + } + std::vector new_new_a_shape, new_new_b_shape; + uint32_t a_dim = 1, b_dim = 1; + for (auto idx : a_1) { + new_new_a_shape.push_back(new_a_shape[idx]); + } + for (auto idx : a_2) { + new_new_a_shape.push_back(new_a_shape[idx]); + } + for (auto idx : a_3) { + a_dim *= new_a_shape[idx]; + } + new_new_a_shape.push_back(a_dim); + for (auto idx : b_1) { + new_new_b_shape.push_back(new_b_shape[idx]); + } + for (auto idx : b_2) { + b_dim *= new_b_shape[idx]; + } + new_new_b_shape.push_back(b_dim); + for (auto idx : b_3) { + new_new_b_shape.push_back(new_b_shape[idx]); + } + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_reshape"); + input_a = model_builder.GetBuilder().call("reshape", + input_a, + emscripten::val::array(new_new_a_shape), + options); + input_b = model_builder.GetBuilder().call("reshape", + input_b, + emscripten::val::array(new_b_shape), + options); + } + + // Step 2. Matmul + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_matmul"); + output = model_builder.GetBuilder().call("matmul", input_a, input_b, options); + std::vector output_indices = a_1; + output_indices.push_back(a_2.back()); + output_indices.push_back(b_3.back()); + + /* + Step 3. Output Transpose: + Use the following fast permutation calculation algorithm + to calculate the permutation of transpose. 
+ sequence x[] -> sequence y[] : permutation p[] + x[s[i]] = i, y[t[i]] = i, p[t[i]] = s[i] + output_indices is x and target_output_indices is y + */ + std::vector target_output_indices(output_labels.begin(), output_labels.end()); + + // map output dim labels to 0 ~ n-1 + std::vector output_indices_sorted(output_indices.begin(), output_indices.end()); + std::map mapping; + std::sort(output_indices_sorted.begin(), output_indices_sorted.end()); + for (size_t i = 0; i < output_indices_sorted.size(); i++) { + mapping[output_indices_sorted[i]] = i; + } + + for (size_t i = 0; i < output_indices.size(); i++) { + output_indices[i] = mapping[output_indices[i]]; + if (i < target_output_indices.size()) { + target_output_indices[i] = mapping[target_output_indices[i]]; + } + } + + uint32_t pad = target_output_indices.size(); + std::vector s(output_indices.size(), -1); + std::vector t(output_indices.size(), -1); + std::vector p(output_indices.size(), 0); + for (uint32_t i = 0; i < output_indices.size(); ++i) { + s[output_indices[i]] = i; + if (i < target_output_indices.size()) { + t[target_output_indices[i]] = i; + } + } + for (uint32_t i = 0; i < output_indices.size(); ++i) { + if (t[i] == -1) { + t[i] = pad++; + } + p[static_cast(t[i])] = static_cast(s[i]); + } + + std::vector sequence_o(output_indices.size()); + std::iota(sequence_o.begin(), sequence_o.end(), 0); + if (p != sequence_o) { + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(p)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", output, options); + } + + // Step 4. Output ReduceSum + if (output_labels.size() < output_indices.size()) { + std::vector axes_data; + for (uint32_t i = output_labels.size(); i < output_indices.size(); ++i) { + axes_data.push_back(SafeInt(i)); + } + emscripten::val options_reduce = emscripten::val::object(); + options_reduce.set("axes", emscripten::val::array(axes_data)); + options_reduce.set("label", node.Name() + "_reduceSum"); + output = model_builder.GetBuilder().call("reduceSum", output, options_reduce); + } + return Status::OK(); +} + +RecognizedOperatorType DetermineRecognizedOperatorType(const std::vector& label_indices, + const std::vector& components, + const std::vector& output_dimensions) { + if (components.empty()) return RecognizedOperatorType::None; + + auto equals = [](gsl::span a, gsl::span b) { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); + }; + + std::array component_ranks; + if (components.size() > component_ranks.size()) { + // So far, not support for more than two inputs and one output. + return RecognizedOperatorType::None; + } else if (components.size() == 2) { // one input + auto input_labels = components[0].GetLabels(label_indices); + auto output_labels = components[1].GetLabels(label_indices); + if (input_labels.size() == output_labels.size()) { + if (equals(input_labels, output_labels)) { + // Identity: input labels = output labels + return RecognizedOperatorType::Identity; + } else { + return RecognizedOperatorType::Transpose; + } + } else if (input_labels.size() == input_labels.back() + 1) { + // ReduceSum: There is no repeated character in input. + return RecognizedOperatorType::ReduceSum; + } else if (input_labels.size() == input_labels.back() + 2) { + // Diagonal: One repeated character in input, ii->i / iij->ij / iijk -> ijk. 
+ return RecognizedOperatorType::Diagonal; + } else { + return RecognizedOperatorType::None; + } + } else if (components.size() == 3) { // two inputs + auto input_A_labels = components[0].GetLabels(label_indices); + auto input_B_labels = components[1].GetLabels(label_indices); + auto output_labels = components[2].GetLabels(label_indices); + if (equals(input_A_labels, output_labels) && equals(input_B_labels, output_labels)) { // element-wise product + return RecognizedOperatorType::Multiply; + } + } + + return RecognizedOperatorType::Pairwise; +} + +// Add operator related. + +Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val output = emscripten::val::object(); + + NodeAttrHelper helper(node); + const auto equation = helper.Get("equation", std::string(" ")); + + std::vector label_indices; + std::vector components; + std::vector output_dimensions; + uint32_t num_labels; + ORT_RETURN_IF_NOT(ParseEquationComponents(node, equation, label_indices, components, output_dimensions, + num_labels, logger), + "Error parsing equation components."); + + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + + switch (recognized_operator_type) { + case RecognizedOperatorType::Multiply: { + emscripten::val a = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val b = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_mul"); + output = model_builder.GetBuilder().call("mul", a, b, options); + } break; + case RecognizedOperatorType::ReduceSum: { + auto kept_axes = components.back().GetLabels(label_indices); + std::vector reduced_axes; + uint32_t kept_axes_mask = 0; + for (auto axis : kept_axes) { + kept_axes_mask |= (1 << axis); + } + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + for (uint32_t axis = 0, axis_count = static_cast(input_shape.size()); axis < axis_count; ++axis) { + if (~kept_axes_mask & (1 << axis)) { + reduced_axes.push_back(axis); + } + } + + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("keepDimensions", false); + options.set("axes", emscripten::val::array(reduced_axes)); + options.set("label", node.Name() + "_reduceSum"); + + output = model_builder.GetBuilder().call("reduceSum", input, options); + + // transpose output + std::vector output_labels_sorted(kept_axes.begin(), kept_axes.end()); + std::map mapping; + std::sort(output_labels_sorted.begin(), output_labels_sorted.end()); + + auto equals = [](std::vector a, gsl::span b) { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); + }; + if (equals(output_labels_sorted, kept_axes)) { + break; + } + + for (size_t i = 0; i < output_labels_sorted.size(); i++) { + mapping[output_labels_sorted[i]] = i; + } + std::vector permutation; + for (auto idx : kept_axes) { + permutation.push_back(mapping[idx]); + } + emscripten::val options_transpose = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", output, options); + } break; + case RecognizedOperatorType::Diagonal: { + emscripten::val 
input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + auto input_labels = components[0].GetLabels(label_indices); + auto output_labels = components[1].GetLabels(label_indices); + uint32_t diagonal_idx_1, diagonal_idx_2; + uint32_t permutation_idx = 0; + for (uint32_t idx = 0; idx < input_labels.size(); idx++) { + if (idx != input_labels[idx]) { + diagonal_idx_1 = input_labels[idx]; + diagonal_idx_2 = idx; + break; + } + } + + // tranpose input + std::vector permutation(input_labels.size()); + for (uint32_t idx = 0; idx < input_labels.size(); idx++) { + if (idx != diagonal_idx_1 && idx != diagonal_idx_2) { + permutation[permutation_idx++] = idx; + } + } + permutation[permutation_idx++] = diagonal_idx_1; + permutation[permutation_idx] = diagonal_idx_2; + + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", input, options); + + // triu + tril = diagonal + emscripten::val options_trilu = emscripten::val::object(); + options_trilu.set("label", node.Name() + "_triangular"); + output = model_builder.GetBuilder().call("triangular", output, options_trilu); // triu + options_trilu.set("upper", false); + output = model_builder.GetBuilder().call("triangular", output, options_trilu); // tril + + // reducesum to achieve the diagonal values + std::vector input_shape; + std::vector reduced_axes; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + if (input_shape[diagonal_idx_1] > input_shape[diagonal_idx_2]) { + reduced_axes.push_back(input_labels.size() - 2); + } else { + reduced_axes.push_back(input_labels.size() - 1); + } + emscripten::val options_reduce = emscripten::val::object(); + options_reduce.set("keepDimensions", false); + options_reduce.set("axes", emscripten::val::array(reduced_axes)); + options_reduce.set("label", node.Name() + "_reduceSum"); + output = model_builder.GetBuilder().call("reduceSum", output, options_reduce); // triu + + // transpose output + std::vector target_output_indices(output_labels.begin(), output_labels.end()); + std::vector output_indices(permutation.begin(), permutation.end() - 1); + + // Use the fast permutation calculation algorithm mentioned above + std::vector s(output_indices.size(), -1); + std::vector t(output_indices.size(), -1); + std::vector p(output_indices.size(), 0); + for (uint32_t i = 0; i < output_indices.size(); ++i) { + s[output_indices[i]] = i; + t[target_output_indices[i]] = i; + } + for (uint32_t i = 0; i < output_indices.size(); ++i) { + p[static_cast(t[i])] = static_cast(s[i]); + } + + std::vector sequence_o(output_indices.size()); + std::iota(sequence_o.begin(), sequence_o.end(), 0); + if (p != sequence_o) { + emscripten::val options_transpose = emscripten::val::object(); + options.set("permutation", emscripten::val::array(p)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", output, options); + } + } break; + + case RecognizedOperatorType::Transpose: { + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + assert(components.front().GetDimensionCount() == components.back().GetDimensionCount()); + // Remap transposed strides using the component labels from input to output. 
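+      // e.g. for the equation "ijk->kji" the output labels are {2, 1, 0}, which is used directly
+      // as the WebNN transpose permutation (output axis d reads from input axis permutation[d]).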
+ auto output_labels = components.back().GetLabels(label_indices); + + std::vector permutation{output_labels.begin(), output_labels.end()}; + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", input, options); + } break; + + case RecognizedOperatorType::Identity: { + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + output = input; + } break; + + case RecognizedOperatorType::Pairwise: { + ORT_RETURN_IF_ERROR(PairwiseOperandProcess(model_builder, node, label_indices, components, + output_dimensions, num_labels, output, logger)); + } break; + + default: + break; + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, + const Node& node, + const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + if (input_defs.size() > 2) { + // TODO: Support more than two inputs. + LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; + return false; + } + + NodeAttrHelper helper(node); + const auto equation = helper.Get("equation", std::string(" ")); + std::vector label_indices; + std::vector components; + std::vector output_dimensions; + uint32_t num_labels; + + if (!ParseEquationComponents(node, equation, label_indices, components, + output_dimensions, num_labels, logger)) { + LOGS(logger, VERBOSE) << "EinSum input equation is illegal."; + return false; + } + + if (static_cast(input_defs.size()) + 1 != components.size()) { + LOGS(logger, VERBOSE) << "EinSum input tensor count is inconsistent with the equation component count."; + return false; + } + + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + if (recognized_operator_type == RecognizedOperatorType::None) { + LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; + return false; + } + + return true; +} + +bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + const auto& op_type = node.OpType(); + int32_t input0_type; + int32_t input1_type; + bool has_input1 = input_defs.size() > 1 && input_defs[1]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + (has_input1 && !GetType(*input_defs[1], input1_type, logger))) { + return false; + } + + if (has_input1 && input0_type != input1_type) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + NodeAttrHelper helper(node); + const auto equation = helper.Get("equation", std::string(" ")); + std::vector label_indices; + std::vector components; + std::vector output_dimensions; + uint32_t num_labels; + + if (!ParseEquationComponents(node, equation, label_indices, + components, output_dimensions, num_labels, logger)) { + LOGS(logger, VERBOSE) << "EinSum input equation is illegal."; + return false; + } + + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + + if (recognized_operator_type == RecognizedOperatorType::None) { + 
LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; + return false; + } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { + // Map to WebNN's gemm or matmul + return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); + } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { + return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); + } else { + return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); + } +} + +void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index c8cea833983b1..f5e1f59602c5d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -88,6 +88,10 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers LOGS(logger, VERBOSE) << "Cannot get shape."; return false; } + if (std::any_of(new_shape.begin(), new_shape.end(), [](int64_t dimension) { return dimension == 0; })) { + LOGS(logger, VERBOSE) << "WebNN expand does not support new shape with 0 dimension."; + return false; + } std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) { @@ -95,11 +99,6 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } - if (input_shape.empty()) { - LOGS(logger, VERBOSE) << "Expand does not support empty input's shape."; - return false; - } - std::vector output_shape; if (!GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape)) { LOGS(logger, VERBOSE) << "The input cannot expand to shape " << GetShapeString(new_shape); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc new file mode 100644 index 0000000000000..cb7b7de74e121 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GatherElementsOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. 
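+// ONNX GatherElements maps directly onto WebNN gatherElements(): e.g. for axis == 0,
+// output[i][j] = data[indices[i][j]][j]. Negative axis values are normalized below before
+// being passed to WebNN.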
+ +Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + const size_t rank = input_shape.size(); + NodeAttrHelper helper(node); + const uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 0), rank)); + options.set("axis", axis); + + emscripten::val output = model_builder.GetBuilder().call("gatherElements", data, indices, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc new file mode 100644 index 0000000000000..002a1a6a63026 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GatherNDOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. 
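Note: the GatherElements axis attribute above is normalized with HandleNegativeAxis before being passed to WebNN, which only accepts non-negative axes. A minimal sketch of that normalization, assuming the usual ONNX range of [-rank, rank - 1] (hypothetical standalone helper, not the provider's code):

// Illustration only: map an ONNX axis (possibly negative) to a non-negative WebNN axis.
#include <cstddef>
#include <cstdint>
#include <stdexcept>

uint32_t NormalizeAxis(int64_t axis, size_t rank) {
  const int64_t signed_rank = static_cast<int64_t>(rank);
  if (axis < -signed_rank || axis >= signed_rank) {
    throw std::out_of_range("axis out of range");
  }
  if (axis < 0) axis += signed_rank;  // e.g. axis -1 with rank 4 becomes 3.
  return static_cast<uint32_t>(axis);
}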
+ +Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("gatherND", data, indices, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("batch_dims", 0) != 0) { + LOGS(logger, VERBOSE) << "GatherND: WebNN only supports batch_dims 0 (default)"; + return false; + } + + return true; +} + +bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index ae9fe3e3f3bd1..88d22f103cadc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,8 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. 
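Note: for the batch_dims == 0 case accepted above, gatherND treats each row of the indices tensor as a full coordinate into the data tensor. A toy standalone illustration for a 2-D data tensor (not the provider's code):

// Illustration only: GatherND with batch_dims == 0 over a row-major 2-D tensor.
#include <cstddef>
#include <utility>
#include <vector>

std::vector<float> GatherND2D(const std::vector<float>& data, size_t cols,
                              const std::vector<std::pair<size_t, size_t>>& indices) {
  std::vector<float> out;
  out.reserve(indices.size());
  for (const auto& [row, col] : indices) {
    out.push_back(data[row * cols + col]);  // Each index pair selects one element.
  }
  return out;
}
// data = {0, 1, 2, 3} viewed as 2x2, indices = {(0, 0), (1, 1)} -> {0, 3}.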
@@ -69,8 +69,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 1477530ce1894..5f4e6de8fda98 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,8 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + a_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } if (input_defs.size() >= 4) { b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + b_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } output = model_builder.GetBuilder().call("matmulInteger", a, @@ -215,8 +215,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // A data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index c92fe7366d494..b240e30d38b22 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,8 +26,10 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const 
InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -185,44 +187,68 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - int32_t input0_type = 0; // input data type - int32_t input1_type = 0; // weight data type - int32_t input2_type = 0; // recurrentWeight data type - int32_t input3_type = 0; // bias data type - int32_t input4_type = 0; // recurrentBias data type - int32_t input5_type = 0; // initialHiddenState data type - bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); - bool has_input4 = input_defs.size() > 4 && input_defs[4]->Exists(); - bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists(); - - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger) || - !GetType(*input_defs[2], input2_type, logger) || - (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || - (has_input4 && !GetType(*input_defs[4], input4_type, logger)) || - (has_input5 && !GetType(*input_defs[5], input5_type, logger))) { + int32_t input_X_type = 0; // input data type + int32_t input_W_type = 0; // weight data type + int32_t input_R_type = 0; // recurrent weight data type + int32_t input_B_type = 0; // bias data type + int32_t input_initial_h_type = 0; // initial hidden state data type + bool has_input_B = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input_initial_h = input_defs.size() > 5 && input_defs[5]->Exists(); + + if (!GetType(*input_defs[0], input_X_type, logger) || + !GetType(*input_defs[1], input_W_type, logger) || + !GetType(*input_defs[2], input_R_type, logger) || + (has_input_B && !GetType(*input_defs[3], input_B_type, logger)) || + // input_defs[4] refers to sequence_lens and is a fixed data type of int32. 
+ (has_input_initial_h && !GetType(*input_defs[5], input_initial_h_type, logger))) { return false; } - InlinedVector input_types = {input0_type, input1_type, input2_type}; - if (has_input3) { - input_types.push_back(input3_type); + InlinedVector input_types = {input_X_type, input_W_type, input_R_type}; + if (has_input_B) { + input_types.push_back(input_B_type); } - if (has_input4) { - input_types.push_back(input4_type); - } - if (has_input5) { - input_types.push_back(input5_type); + if (has_input_initial_h) { + input_types.push_back(input_initial_h_type); } if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); +} + +bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output_defs = node.OutputDefs(); + const auto& op_type = node.OpType(); + int32_t Y_type = 0; + int32_t Y_h_type = 0; + bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); + bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + + bool Y_supported = has_Y && GetType(*output_defs[0], Y_type, logger); + bool Y_h_supported = has_Y_h && GetType(*output_defs[1], Y_h_type, logger); + + if (Y_supported && !Y_h_supported) { + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + } else if (!Y_supported && Y_h_supported) { + return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger); + } else if (Y_supported && Y_h_supported) { + if (Y_type != Y_h_type) { + LOGS(logger, VERBOSE) << "[GRU] Output data types must be the same."; + return false; + } + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + } else { + LOGS(logger, VERBOSE) << "[GRU] No output found."; + return false; + } } void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index ea7f70b4598e6..91910f55f37c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,8 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. 
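Note: the reworked AddToModelBuilderImpl below resolves the WebNN op name through GetWebNNOpType. A hypothetical standalone table showing how the ONNX logical and comparison ops registered in this builder could map to WebNN builder method names; the table is an assumption for illustration, not the helper's actual contents, though the comparison names match the calls removed below:

// Assumed ONNX-to-WebNN op-name mapping for the logical op builder.
#include <string>
#include <unordered_map>

const std::unordered_map<std::string, std::string>& LogicalOpMap() {
  static const std::unordered_map<std::string, std::string> map = {
      {"And", "logicalAnd"}, {"Or", "logicalOr"}, {"Xor", "logicalXor"}, {"Not", "logicalNot"},
      {"Equal", "equal"},    {"Greater", "greater"}, {"GreaterOrEqual", "greaterOrEqual"},
      {"Less", "lesser"},    {"LessOrEqual", "lesserOrEqual"},
  };
  return map;
}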
@@ -33,28 +33,20 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons const auto& op_type = node.OpType(); emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val input1 = emscripten::val::undefined(); - if (input_defs.size() > 1) { - input1 = model_builder.GetOperand(input_defs[1]->Name()); - } emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - if (op_type == "Equal") { - output = model_builder.GetBuilder().call("equal", input0, input1, options); - } else if (op_type == "Greater") { - output = model_builder.GetBuilder().call("greater", input0, input1, options); - } else if (op_type == "GreaterOrEqual") { - output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1, options); - } else if (op_type == "Less") { - output = model_builder.GetBuilder().call("lesser", input0, input1, options); - } else if (op_type == "LessOrEqual") { - output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); - } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input0, options); + + std::string webnn_op_type; + ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type"); + + if (input_defs.size() == 1) { + // Not + output = model_builder.GetBuilder().call(webnn_op_type.c_str(), input0, options); } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + input1 = model_builder.GetOperand(input_defs[1]->Name()); + output = model_builder.GetBuilder().call(webnn_op_type.c_str(), input0, input1, options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); @@ -68,16 +60,19 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2 && op_type != "Not") { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " - << input_defs.size(); + + size_t expected_input_count = (op_type == "Not") ? 1 : 2; + if (input_defs.size() != expected_input_count) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] expected input count: " + << expected_input_count << ", actual: " << input_defs.size(); return false; } + return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LogicalOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; @@ -105,12 +100,15 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& static std::vector op_types = { + "And", "Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual", "Not", + "Or", + "Xor", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc new file mode 100644 index 0000000000000..19f6d6aff8f97 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class LRNOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + int32_t input_data_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type"); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + const auto node_name = node.Name(); + emscripten::val wnn_builder = model_builder.GetBuilder(); + + NodeAttrHelper helper(node); + const float alpha = helper.Get("alpha", 0.0001f); + const float beta = helper.Get("beta", 0.75f); + const float bias = helper.Get("bias", 1.0f); + const uint32_t size = helper.Get("size", 1); + + // Prepare WebNN constants for alpha, beta, bias attributes. + // Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'. + emscripten::val alpha_constant = model_builder.CreateOrGetConstant(input_data_type, alpha); + emscripten::val beta_constant = model_builder.CreateOrGetConstant(input_data_type, beta); + emscripten::val bias_constant = model_builder.CreateOrGetConstant(input_data_type, bias); + emscripten::val pow1_constant = model_builder.CreateOrGetConstant(input_data_type, 2); + + /** + WebNN doesn't support LRN. So decompose it into a series of ops: + X --> Pow --> (Transpose)--> Pad --> AveragePool--> (Transpose) --> Mul --> Add --> Pow --> Div + ^ ^ ^ ^ ^ ^ ^ ^ + | | | | | | | | + Y:2 (0,2,3,1) Kernel:(1,size) (0,3,1,2) B:alpha B:bias B:beta A:input + */ + // + // pow(input, 2) + emscripten::val label_options = emscripten::val::object(); + label_options.set("label", node_name + "_pow1"); + emscripten::val pow1_output = wnn_builder.call("pow", input, pow1_constant, label_options); + + // transpose(pow1_output, permutation=[0, 2, 3, 1]) + // LRN is one of NHWC layout sensitive ops. When preferred layout is NCHW, move dimension 1 to dimension 3 (rightmost). + if (model_builder.GetPreferredLayout() == DataLayout::NCHW) { + std::vector perm{0, 2, 3, 1}; + emscripten::val transpose_options = emscripten::val::object(); + transpose_options.set("label", node_name + "_transpose_rightmost"); + transpose_options.set("permutation", emscripten::val::array(perm)); + pow1_output = + wnn_builder.call("transpose", pow1_output, transpose_options); + } + + // pad(pow1_output, beginning_padding = {0, 0, 0, leading_padding}, ending_padding = {0, 0, 0, trailing_padding}) + // Adding a Pad before averagePool2d and calling AveragePool with pads as 0's. 
+  const uint32_t leading_padding = floor((size - 1) / 2);
+  const uint32_t trailing_padding = ceil((size - 1) / 2);
+  std::vector<uint32_t> beginning_padding{0, 0, 0, leading_padding};
+  std::vector<uint32_t> ending_padding{0, 0, 0, trailing_padding};
+  emscripten::val pad_options = emscripten::val::object();
+  pad_options.set("label", node_name + "_pad");
+  emscripten::val pad_output =
+      wnn_builder.call<emscripten::val>("pad", pow1_output, emscripten::val::array(beginning_padding),
+                                        emscripten::val::array(ending_padding), pad_options);
+
+  // averagePool2d(pad_output, pool_options)
+  const std::vector<uint32_t> kernel_shape = {1, size};
+  emscripten::val pool_options = emscripten::val::object();
+  pool_options.set("label", node_name + "_averagePool2d");
+  pool_options.set("windowDimensions", emscripten::val::array(kernel_shape));
+  emscripten::val pool_output = wnn_builder.call<emscripten::val>("averagePool2d", pad_output, pool_options);
+
+  // transpose(pool_output, permutation=[0, 3, 1, 2])
+  // Move dimension 3 back to dimension 1.
+  if (model_builder.GetPreferredLayout() == DataLayout::NCHW) {
+    std::vector<uint32_t> perm{0, 3, 1, 2};
+    emscripten::val transpose_options = emscripten::val::object();
+    transpose_options.set("label", node_name + "_transpose_inverse");
+    transpose_options.set("permutation", emscripten::val::array(perm));
+    pool_output =
+        wnn_builder.call<emscripten::val>("transpose", pool_output, transpose_options);
+  }
+
+  // mul(pool_output, alpha_constant)
+  label_options.set("label", node_name + "_mul");
+  emscripten::val mul_output =
+      wnn_builder.call<emscripten::val>("mul", pool_output, alpha_constant, label_options);
+
+  // add(mul_output, bias_constant)
+  label_options.set("label", node_name + "_add");
+  emscripten::val add_output = wnn_builder.call<emscripten::val>("add", mul_output, bias_constant, label_options);
+
+  // pow(add_output, beta_constant)
+  label_options.set("label", node_name + "_pow2");
+  emscripten::val pow2_output = wnn_builder.call<emscripten::val>("pow", add_output, beta_constant, label_options);
+
+  // div(input, pow2_output)
+  label_options.set("label", node_name + "_div");
+  emscripten::val div_output = wnn_builder.call<emscripten::val>("div", input, pow2_output, label_options);
+
+  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(div_output));
+  return Status::OK();
+}
+
+// Operator support related.
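Note: the padding above implements a centred window of size elements, splitting the size - 1 surrounding elements into floor((size - 1) / 2) leading and ceil((size - 1) / 2) trailing ones. A standalone sketch of that split, written so the ceiling survives integer arithmetic:

// Illustration only: leading/trailing padding for a centred window of `size` elements.
#include <cstdint>
#include <utility>

std::pair<uint32_t, uint32_t> CentredWindowPadding(uint32_t size) {
  const uint32_t leading = (size - 1) / 2;  // floor((size - 1) / 2)
  const uint32_t trailing = size / 2;       // ceil((size - 1) / 2) for size >= 1
  return {leading, trailing};
}
// CentredWindowPadding(5) == {2, 2}; CentredWindowPadding(4) == {1, 2}.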
+bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << "LRN only supports 4D input shape, input is " + << input_size << "D shape"; + return false; + } + + return true; +} + +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 6213b039fb2f9..33ba22ac3fb5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -25,8 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -198,8 +198,8 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LstmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index e111ca412c6e9..40f94186e9ed6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,8 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. 
@@ -87,8 +87,8 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool MaxMinOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index a3c6b8fdcea9b..50e49884bdfa9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,8 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -72,7 +72,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } NodeAttrHelper helper(node); - options.set("epsilon", helper.Get("epsilon", 1e-05f)); + const auto epsilon = helper.Get("epsilon", 1e-05f); + options.set("epsilon", epsilon); emscripten::val output = emscripten::val::undefined(); if (op_type == "BatchNormalization") { @@ -84,14 +85,59 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); - } else if (op_type == "LayerNormalization") { + } else if (op_type == "LayerNormalization" || op_type == "SimplifiedLayerNormalization") { int64_t axis = helper.Get("axis", -1); axis = HandleNegativeAxis(axis, rank); std::vector axes(rank - SafeInt(axis)); std::iota(axes.begin(), axes.end(), axis); - options.set("axes", emscripten::val::array(axes)); - output = model_builder.GetBuilder().call("layerNormalization", input, options); + if (op_type == "LayerNormalization") { + options.set("axes", emscripten::val::array(axes)); + output = model_builder.GetBuilder().call("layerNormalization", input, options); + } else { // SimplifiedLayerNormalization + /** + WebNN doesn't support SimplifiedLayerNormalization. 
So decompose it into a series of ops: + X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul + ^ ^ ^ ^ ^ + | | | | | + Y:2 axis B:epsilon A:X A:scale + */ + + int32_t input_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type"); + emscripten::val common_options = emscripten::val::object(); + + // Pow + emscripten::val pow_constant = model_builder.CreateOrGetConstant(input_type, 2); + common_options.set("label", node.Name() + "_pow"); + emscripten::val pow = + model_builder.GetBuilder().call("pow", input, pow_constant, common_options); + + // ReduceMean + emscripten::val reduce_options = emscripten::val::object(); + reduce_options.set("axes", emscripten::val::array(axes)); + reduce_options.set("keepDimensions", true); + reduce_options.set("label", node.Name() + "_reduceMean"); + emscripten::val reduce_mean = model_builder.GetBuilder().call("reduceMean", pow, reduce_options); + + // Add + emscripten::val add_constant = model_builder.CreateOrGetConstant(input_type, epsilon); + common_options.set("label", node.Name() + "_add"); + emscripten::val add = + model_builder.GetBuilder().call("add", reduce_mean, add_constant, common_options); + + // Sqrt + common_options.set("label", node.Name() + "_sqrt"); + emscripten::val sqrt = model_builder.GetBuilder().call("sqrt", add, common_options); + + // Div + common_options.set("label", node.Name() + "_div"); + emscripten::val div = model_builder.GetBuilder().call("div", input, sqrt, common_options); + + // Mul + common_options.set("label", node.Name() + "_mul"); + output = model_builder.GetBuilder().call("mul", scale, div, common_options); + } } else if (op_type == "InstanceNormalization") { // WebNN spec only supports 4D input for instanceNormalization. // Supports 3D input by prepending 1 size dimension. @@ -182,7 +228,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -229,6 +276,7 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat "BatchNormalization", "InstanceNormalization", "LayerNormalization", + "SimplifiedLayerNormalization", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index 13dee667f6fd9..bd7c23d75eba4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -22,8 +22,10 @@ class QDQOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. 
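Note: for reference, the SimplifiedLayerNormalization decomposition above (pow, reduceMean, add epsilon, sqrt, div, mul) computes y = x * scale / sqrt(mean(x^2) + epsilon) over the normalized axes. A standalone single-axis sketch of that formula, illustrative only:

// Reference computation of SimplifiedLayerNormalization over one axis.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> SimplifiedLayerNorm(const std::vector<float>& x,
                                       const std::vector<float>& scale,
                                       float epsilon) {
  double mean_sq = 0.0;
  for (float v : x) mean_sq += static_cast<double>(v) * v;
  mean_sq /= static_cast<double>(x.size());
  const float inv_rms = 1.0f / std::sqrt(static_cast<float>(mean_sq) + epsilon);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * inv_rms * scale[i];  // mul(scale, div(x, sqrt(...)))
  return y;
}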
- bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -35,85 +37,123 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector input_shape; std::vector scale_shape; + std::vector zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); int32_t input_type = 0; int32_t output_type = 0; int32_t zero_point_type = 0; + bool has_zero_point = false; ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input data type"); ORT_RETURN_IF_NOT(GetType(*output_defs[0], output_type, logger), "Cannot get output data type"); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); - emscripten::val zero_point = emscripten::val::null(); + if (input_defs.size() == 3 && input_defs[2]->Exists()) { zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + has_zero_point = true; } else { // DequantizeLinear: x_zero_point's data type equals to input data type // QuantizeLinear: x_zero_point's data type equals to output data type zero_point_type = op_type == "DequantizeLinear" ? input_type : output_type; - zero_point = model_builder.GetZeroConstant(zero_point_type); } - emscripten::val output; + const auto input_rank = input_shape.size(); NodeAttrHelper helper(node); - int32_t axis = helper.Get("axis", 1); int32_t block_size = helper.Get("block_size", 0); - // axis is valid for input shape greater than 1D. - if (input_shape.size() > 1) { - axis = static_cast(HandleNegativeAxis(axis, input_shape.size())); + int32_t axis = helper.Get("axis", 1); + if (axis < 0) { + axis = SafeInt(HandleNegativeAxis(axis, input_rank)); } - // Insert ones before and after the axis dimension for broadcasting of 1D scale tensor. - if (1 == scale_shape.size() && 1 < input_shape.size()) { - std::vector target_shape{static_cast(input_shape[axis])}; + + // For per-axis quantization/dequantization and axis is not equal to input_rank - 1, + // we need to reshape the scale and zero_point tensors to make them broadcastable with the input tensor. + if (scale_shape.size() == 1 && input_rank > 1 && + block_size == 0 && axis != static_cast(input_rank - 1)) { + // Insert ones before and after the axis dimension for broadcasting of scale tensor. + std::vector target_shape{SafeInt(input_shape[axis])}; target_shape.insert(target_shape.begin(), axis, 1); - target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); + target_shape.insert(target_shape.end(), input_rank - axis - 1, 1); + // zero_point has the same shape as the scale tensor. 
+ zero_point_shape = target_shape; emscripten::val reshape_scale_options = emscripten::val::object(); reshape_scale_options.set("label", node.Name() + "_reshape_scale"); scale = model_builder.GetBuilder().call("reshape", scale, emscripten::val::array(target_shape), reshape_scale_options); - emscripten::val reshape_zero_point_options = emscripten::val::object(); - reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); - zero_point = model_builder.GetBuilder().call("reshape", - zero_point, - emscripten::val::array(target_shape), - reshape_zero_point_options); - } - // If block_size is specified, we need to expand the scale and zero_point tensors. - if (block_size > 1) { - emscripten::val concat_scale_inputs = emscripten::val::array(); - emscripten::val concat_zero_point_inputs = emscripten::val::array(); - for (int i = 0; i < block_size; i++) { - concat_scale_inputs.call("push", scale); - concat_zero_point_inputs.call("push", zero_point); + if (has_zero_point) { + // Reshape the zero_point tensor too. + emscripten::val reshape_zero_point_options = emscripten::val::object(); + reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); + zero_point = model_builder.GetBuilder().call("reshape", + zero_point, + emscripten::val::array(target_shape), + reshape_zero_point_options); } + } - emscripten::val concat_scale_options = emscripten::val::object(); - concat_scale_options.set("label", node.Name() + "_concat_scale"); - scale = model_builder.GetBuilder().call("concat", concat_scale_inputs, axis, concat_scale_options); - - emscripten::val concat_zero_point_options = emscripten::val::object(); - concat_zero_point_options.set("label", node.Name() + "_concat_zero_point"); - zero_point = model_builder.GetBuilder().call( - "concat", concat_zero_point_inputs, axis, concat_zero_point_options); + // If zero_point is not provided, create a zero constant with the same shape as the scale tensor. + if (!has_zero_point) { + if (zero_point_shape.empty()) { + // zero_point has the same shape as the scale tensor. + zero_point_shape = GetVecUint32FromVecInt64(scale_shape); + } + // Create a zero constant with the same shape as the scale tensor. + // The zero value has been pre-processed in the CreateOrGetConstant function, + // so the type of T is not relevant here. + zero_point = model_builder.CreateOrGetConstant(zero_point_type, 0, zero_point_shape); } emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); std::string webnn_op_type; ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type"); - output = model_builder.GetBuilder().call(webnn_op_type.c_str(), input, scale, zero_point, options); + emscripten::val output = + model_builder.GetBuilder().call(webnn_op_type.c_str(), input, scale, zero_point, options); model_builder.AddOperand(output_defs[0]->Name(), std::move(output)); return Status::OK(); } -bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +// Operator support related. 
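Note: the per-axis reshape above turns a 1-D scale of length input_shape[axis] into a shape that broadcasts against the input during quantizeLinear/dequantizeLinear. A minimal standalone sketch of that target shape (hypothetical helper name, not the builder's code):

// Illustration only: broadcastable shape for a 1-D per-axis scale or zero point.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<uint32_t> PerAxisScaleShape(const std::vector<uint32_t>& input_shape, size_t axis) {
  std::vector<uint32_t> target(input_shape.size(), 1);  // All ones except the quantization axis.
  target[axis] = input_shape[axis];
  return target;
}
// input_shape = {2, 3, 4, 5}, axis = 1 -> {1, 3, 1, 1}.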
+bool QDQOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + std::vector scale_shape; + + if (!GetShape(*input_defs[0], input_shape, logger) || !GetShape(*input_defs[1], scale_shape, logger)) { + return false; + } + + // WebNN requires the scale_shape to be a subsample of the input_shape. + if (scale_shape.size() > input_shape.size()) { + LOGS(logger, VERBOSE) << "The rank of scale is larger than the rank of input"; + return false; + } + + for (size_t i = 0; i < scale_shape.size(); ++i) { + auto scale_dim = scale_shape[scale_shape.size() - i - 1]; + auto input_dim = input_shape[input_shape.size() - i - 1]; + if (input_dim % scale_dim != 0) { + LOGS(logger, VERBOSE) << "The shape of scale is not a subsample of the shape of input"; + return false; + } + } + + return true; +} + +bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index a7911683f0355..0a438e98ad737 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -44,21 +44,25 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name()); - const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() - ? reinterpret_cast(target_shape_tensor.raw_data().data()) - : target_shape_tensor.int64_data().data(); + const auto& target_shape_tensor_dims = target_shape_tensor.dims(); + std::vector new_shape; + // Do nothing if target shape is an empty shape, which means converting to a scalar. + if (!target_shape_tensor_dims.empty()) { + const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() + ? 
reinterpret_cast(target_shape_tensor.raw_data().data()) + : target_shape_tensor.int64_data().data(); + + const auto size = target_shape_tensor_dims[0]; + TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + ReshapeHelper helper(TensorShape(input_shape), target_shape); + std::transform(target_shape.cbegin(), target_shape.cend(), + std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + } - const auto size = target_shape_tensor.dims()[0]; - TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - ReshapeHelper helper(TensorShape(input_shape), target_shape); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - std::vector new_shape; - std::transform(target_shape.cbegin(), target_shape.cend(), - std::back_inserter(new_shape), - [](int64_t dim) -> uint32_t { return SafeInt(dim); }); - emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("reshape", @@ -76,6 +80,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto& perm_name = input_defs[1]->Name(); if (!Contains(initializers, perm_name)) { LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer"; @@ -92,24 +101,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer const int64_t* raw_new_shape = reinterpret_cast(unpacked_tensor.data()); const auto& perm_dims = perm_tensor.dims(); - if (perm_dims.empty() || perm_dims[0] == 0) { - LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; - return false; - } - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_shape.empty()) { - LOGS(logger, VERBOSE) << "Reshape does not support empty input shape"; - return false; - } // WebNN reshape does not support 0 as dimension. NodeAttrHelper helper(node); - const bool allow_zero = helper.Get("allowzero ", 0) == 1; - if (allow_zero) { + const bool allow_zero = helper.Get("allowzero", 0) == 1; + if (allow_zero && !perm_dims.empty()) { for (int64_t i = 0; i < perm_dims[0]; i++) { if (raw_new_shape[i] == 0) { LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 9dc79f4f52f46..00f8cff25ccf5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -21,6 +21,8 @@ namespace webnn { class ResizeOpBuilder : public BaseOpBuilder { // Add operator related. public: + // Allow roi and scales potentially being empty inputs that are ignored during processing. 
+ ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {} void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; private: @@ -38,16 +40,33 @@ class ResizeOpBuilder : public BaseOpBuilder { }; // Helper functions -bool GetResizeScales(const InitializedTensorSet& initializers, - const Node& node, std::vector& scales, - const logging::Logger& logger) { +bool GetResizeScalesAndAxes(const InitializedTensorSet& initializers, + const Node& node, std::vector& scales, + std::vector& axes, const bool is_nhwc, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 3) return false; + const bool has_axes = !axes.empty(); const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); - if (scales_tensor.dims_size() != 1 || scales_tensor.dims()[0] != 4) + if (scales_tensor.dims_size() != 1) { + LOGS(logger, ERROR) << "'scales' should be a 1D tensor."; return false; + } + + // Number of elements of 'scales' tensor. + const auto num_of_scales = scales_tensor.dims()[0]; + + if (has_axes && num_of_scales != 2) { + LOGS(logger, ERROR) << "When 'axes' is provided, 'scales' should have 2 elements."; + return false; + } + + if (!has_axes && num_of_scales != 4) { + LOGS(logger, ERROR) << "When 'axes' is not provided, 'scales' should have 4 elements."; + return false; + } std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor); @@ -56,20 +75,65 @@ bool GetResizeScales(const InitializedTensorSet& initializers, return false; } const float* scales_data = reinterpret_cast(unpacked_tensor.data()); - scales = std::vector{scales_data, scales_data + 4}; + + if (has_axes) { + // 'axes' is specified since opset 18+, 'scales' should have 2 elements. + scales = std::vector{scales_data, scales_data + 2}; + } else { + // Before opset 18, 'scales' should have 4 elements. + // Make sure 'scales' is not trying to scale on N/C channels here. + std::vector onnx_scales{scales_data, scales_data + 4}; + // 'scales' input has been transposed to NHWC layout if it is NHWC preferred layout. + const float scale_n = onnx_scales[0]; + const float scale_c = is_nhwc ? onnx_scales[3] : onnx_scales[1]; + const float scale_h = is_nhwc ? onnx_scales[1] : onnx_scales[2]; + const float scale_w = is_nhwc ? onnx_scales[2] : onnx_scales[3]; + if (scale_n != 1.0f || scale_c != 1.0f) { + LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" + << "Scales of N/C channels are not supported" + << ", scale_n, " << scale_n << ", scale_c, " << scale_c; + return false; + } + + scales = {scale_h, scale_w}; + axes = {2, 3}; + } + + if (is_nhwc) { + // For NHWC preferred layout, we need to convert axes from NCHW to NHWC. 
+ axes = convertAxesFromNCHWtoNHWC(axes); + } + return true; } -bool GetResizeOutputSizes(const InitializedTensorSet& initializers, - const Node& node, std::vector& sizes, - const logging::Logger& logger) { +bool GetResizeSizesAndAxes(const InitializedTensorSet& initializers, + const Node& node, std::vector& sizes, + std::vector& axes, const bool is_nhwc, + const gsl::span& input_shape, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 4) return false; + const bool has_axes = !axes.empty(); const auto& sizes_tensor = *initializers.at(input_defs[3]->Name()); - if (sizes_tensor.dims_size() != 1 || sizes_tensor.dims()[0] != 4) + if (sizes_tensor.dims_size() != 1) { + LOGS(logger, ERROR) << "'sizes' should be a 1D tensor."; return false; + } + + // Number of elements of sizes tensor. + const auto num_of_sizes = sizes_tensor.dims()[0]; + if (has_axes && num_of_sizes != 2) { + LOGS(logger, ERROR) << "When 'axes' is provided, 'sizes' should have 2 elements."; + return false; + } + + if (!has_axes && num_of_sizes != 4) { + LOGS(logger, ERROR) << "When 'axes' is not provided, 'sizes' should have 4 elements."; + return false; + } std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor); @@ -78,7 +142,35 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers, return false; } const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); - sizes = std::vector{sizes_data, sizes_data + 4}; + + if (has_axes) { + // 'axes' is specified since opset 18+, 'sizes' should have 2 elements. + sizes = std::vector{sizes_data, sizes_data + 2}; + } else { + // Before opset 18, 'sizes' should have 4 elements. + // Make sure 'sizes' is not trying to resize on N/C channels here. + std::vector onnx_sizes{sizes_data, sizes_data + 4}; + auto size_n = onnx_sizes[0]; + const int c_idx = is_nhwc ? 3 : 1; + if (size_n != input_shape[0] || onnx_sizes[c_idx] != input_shape[c_idx]) { + LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " + << "Resize of N/C channels are not supported" + << ", input_size_n, " << input_shape[0] << ", output_size_n, " << size_n + << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << onnx_sizes[c_idx]; + return false; + } + // 'sizes' input has been transposed to NHWC layout if it is NHWC preferred layout. + const int64_t sizes_h = is_nhwc ? onnx_sizes[1] : onnx_sizes[2]; + const int64_t sizes_w = is_nhwc ? onnx_sizes[2] : onnx_sizes[3]; + sizes = {sizes_h, sizes_w}; + axes = {2, 3}; + } + + if (is_nhwc) { + // For NHWC preferred layout, we need to convert 'axes' from NCHW to NHWC. 
+ axes = convertAxesFromNCHWtoNHWC(axes); + } + return true; } @@ -103,9 +195,15 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + + const auto& initializers(model_builder.GetInitializerTensors()); + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - NodeAttrHelper helper(node); const auto mode = helper.Get("mode", "nearest"); if (mode == "linear") { options.set("mode", emscripten::val("linear")); @@ -113,45 +211,30 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("mode", emscripten::val("nearest-neighbor")); } - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - std::vector scales; - std::vector sizes; - std::vector scales_hw; - std::vector sizes_hw; - std::vector axes; - std::string scales_name = GetTensorName(input_defs, 2); + std::vector sizes; + std::vector webnn_sizes; + std::vector axes = GetResolvedAxes(helper, 4); // We already checked input shape is 4D in IsOpSupportedImpl. + std::string sizes_name = GetTensorName(input_defs, 3); const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; - if (!scales_name.empty()) { // Use scales. - ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); - if (is_nhwc) { - scales_hw = {scales[1], scales[2]}; - } else { - scales_hw = {scales[2], scales[3]}; - } - options.set("scales", emscripten::val::array(scales_hw)); - } else { // Use sizes, we already checked inputs in IsOpSupportedImpl. - std::vector output_sizes; - ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), - "Error getting resize output_sizes"); - std::transform(output_sizes.cbegin(), output_sizes.cend(), - std::back_inserter(sizes), - [](int64_t dim) -> int32_t { return SafeInt(dim); }); - if (is_nhwc) { - sizes_hw = {sizes[1], sizes[2]}; - } else { - sizes_hw = {sizes[2], sizes[3]}; - } - options.set("sizes", emscripten::val::array(sizes_hw)); - } - if (is_nhwc) { - axes = {1, 2}; + // We know we have either a 'scales' or 'sizes' input so this is safe. + // Check for 'sizes' first. + // This handles Resize-11 where 'scales' was a required input but 'sizes' were used if provided. 
+ bool using_sizes = !sizes_name.empty() && Contains(initializers, sizes_name); + if (using_sizes) { + ORT_RETURN_IF_NOT(GetResizeSizesAndAxes(initializers, node, sizes, axes, is_nhwc, input_shape, logger), + "Error getting Resize sizes"); + webnn_sizes = GetVecUint32FromVecInt64(sizes); + options.set("sizes", emscripten::val::array(webnn_sizes)); } else { - axes = {2, 3}; + ORT_RETURN_IF_NOT(GetResizeScalesAndAxes(initializers, node, scales, axes, is_nhwc, logger), + "Error getting Resize scales"); + options.set("scales", emscripten::val::array(scales)); } - options.set("axes", emscripten::val::array(axes)); + + std::vector webnn_axes = GetVecUint32FromVecInt64(axes); + options.set("axes", emscripten::val::array(webnn_axes)); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = model_builder.GetBuilder().call("resample2d", input, options); @@ -166,6 +249,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) @@ -179,92 +263,75 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } { // Check attributes. - NodeAttrHelper helper(node); - const auto mode = helper.Get("mode", "nearest"); - bool is_linear_resize = mode == "linear"; - bool is_nearest_resize = mode == "nearest"; - // WebNN only supports "linear" and "nearest" modes. - if (!is_linear_resize && !is_nearest_resize) { - LOGS(logger, VERBOSE) << "Resize does not support input mode: " << mode; + // antialias + if (helper.Get("antialias", 0) != 0) { + LOGS(logger, VERBOSE) << "Resize does not support antialias"; return false; } + // Ignore coordinate_transformation_mode because WebNN only supports half_pixel mode. + // TODO: Validate coordinate_transformation_mode. Related spec issue for supporting attribute coordinate + // transformation modes: https://github.com/webmachinelearning/webnn/issues/270 + + // exclude_outside const auto exclude_outside = helper.Get("exclude_outside", 0); if (exclude_outside != 0) { LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; return false; } - } - { // scales and sizes (if present) must be initializers. 
- const std::string scales_name = GetTensorName(input_defs, 2); - const std::string sizes_name = GetTensorName(input_defs, 3); - - // scales (scales may be empty tensor) - bool has_scales = !scales_name.empty(); - if ((has_scales && !Contains(initializers, scales_name)) || (!has_scales && node.SinceVersion() == 11)) { - LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + // keep_aspect_ratio_policy + const auto keep_aspect_ratio_policy = helper.Get("keep_aspect_ratio_policy", "stretch"); + if (keep_aspect_ratio_policy != "stretch") { + LOGS(logger, VERBOSE) << "Resize does not support keep_aspect_ratio_policy: " << keep_aspect_ratio_policy; return false; } - // sizes (sizes may be empty tensor) - bool has_sizes = !sizes_name.empty(); - if (has_sizes && !Contains(initializers, sizes_name)) { - LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; - return false; - } - - if (has_scales && has_sizes) { - LOGS(logger, VERBOSE) << "Only one of 'scales' and 'sizes' can be specified"; + // mode + const auto mode = helper.Get("mode", "nearest"); + bool is_linear_resize = mode == "linear"; + bool is_nearest_resize = mode == "nearest"; + // WebNN only supports "linear" and "nearest" modes. + if (!is_linear_resize && !is_nearest_resize) { + LOGS(logger, VERBOSE) << "Resize does not support input mode: " << mode; return false; } + } - const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; - // We want to check if the scales or sizes are not trying to resize on N/C channels here. - if (has_scales) { // We are using scales. - std::vector scales; - if (!GetResizeScales(initializers, node, scales, logger)) - return false; - - float scale_n = scales[0]; - float scale_c = is_nhwc ? scales[3] : scales[1]; - if (scale_n != 1.0f || scale_c != 1.0f) { - LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" - << "Resize of N/C channels are not supported" - << ", scale_n, " << scale_n << ", scale_c, " << scale_c; - return false; - } + { // 'scales' and 'sizes' (if present) must be non-empty initializers. + const std::string scales_name = GetTensorName(input_defs, 2); + const std::string sizes_name = GetTensorName(input_defs, 3); - // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1. - // TODO support ResizeBilinear. - float scale_h = is_nhwc ? scales[1] : scales[2]; - float scale_w = is_nhwc ? scales[2] : scales[3]; + // Check for 'sizes' first. + // This handles Resize-11 where 'scales' was a required input but 'sizes' were used if provided. + // 'scales' or 'sizes' may be empty tensor. + bool using_sizes = !IsEmptyTensor(initializers, sizes_name); + bool using_scales = !using_sizes && !IsEmptyTensor(initializers, scales_name); - // Onnx spec requires scale to be a positive float, so we are not checking that here. - if (roundf(scale_h) != scale_h) { - LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " is not a whole number"; - return false; - } + if (!using_scales && !using_sizes) { + LOGS(logger, VERBOSE) << "Resize: only one of 'scales' and 'sizes' can be specified"; + return false; + } - if (roundf(scale_w) != scale_w) { - LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; + // 'axes' is from opset 18 on and allows 'scales' or 'sizes' to have entries for the subset of 'axes'. + // We fill with default values if necessary so that the processing is consistent across all supported opsets. 
+ std::vector axes = GetResolvedAxes(helper, input_size); + if (!axes.empty()) { // We have 'axes' attribute. + if (axes.size() != 2 || axes[0] >= input_size || axes[1] >= input_size) { + LOGS(logger, VERBOSE) << "Resize: invalid axes attribute"; return false; } } - if (has_sizes) { - // We are using sizes. - std::vector output_sizes; - if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) + const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; + if (using_sizes) { // We are using 'sizes'. + std::vector sizes; + if (!GetResizeSizesAndAxes(initializers, node, sizes, axes, is_nhwc, input_shape, logger)) { return false; - - auto output_size_n = output_sizes[0]; - const int c_idx = is_nhwc ? 3 : 1; - if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) { - LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " - << "Resize of N/C channels are not supported" - << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n - << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx]; + } + } else { // We are using 'scales'. + std::vector scales; + if (!GetResizeScalesAndAxes(initializers, node, scales, axes, is_nhwc, logger)) { return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc new file mode 100644 index 0000000000000..8c70525835059 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ScatterElementsOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. 
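For context on what the scatterElements call in the builder below has to compute: with reduction left at "none", ONNX ScatterElements copies each updates element into the output at the position given by indices along 'axis', keeping the element's own coordinates on every other axis. A small reference sketch for the 2-D, row-major case (illustrative only; indices are assumed non-negative):

#include <cstdint>
#include <vector>

// Reference semantics of ScatterElements (reduction="none") on a 2-D row-major tensor.
// data/output have shape [rows, cols]; indices and updates share the shape [r, c].
static std::vector<float> ScatterElements2D(const std::vector<float>& data,
                                            int64_t rows, int64_t cols,
                                            const std::vector<int64_t>& indices,
                                            const std::vector<float>& updates,
                                            int64_t r, int64_t c, int axis) {
  (void)rows;  // kept for shape documentation; bounds checking is omitted in this sketch
  std::vector<float> output = data;  // start from a copy of the input
  for (int64_t i = 0; i < r; ++i) {
    for (int64_t j = 0; j < c; ++j) {
      const int64_t idx = indices[i * c + j];
      const int64_t out_row = (axis == 0) ? idx : i;
      const int64_t out_col = (axis == 0) ? j : idx;
      output[out_row * cols + out_col] = updates[i * c + j];
    }
  }
  return output;
}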
+ +Status ScatterElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + const size_t rank = input_shape.size(); + NodeAttrHelper helper(node); + const uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 0), rank)); + options.set("axis", axis); + + emscripten::val output = + model_builder.GetBuilder().call("scatterElements", data, indices, updates, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("reduction", "none") != "none") { + LOGS(logger, VERBOSE) << "ScatterElements: WebNN only supports reduction type none (default)"; + return false; + } + + return true; +} + +bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& updates = *node.InputDefs()[2]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + int32_t updates_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) || + !GetType(updates, updates_type, logger)) { + return false; + } + + if (data_type != updates_type) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc new file mode 100644 index 0000000000000..8089b9706886f --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ScatterNDOpBuilder : public BaseOpBuilder { + // Add operator related. 
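Similarly, ONNX ScatterND with reduction "none" (the only mode the builder below accepts) writes each update at the coordinate tuple given by the last dimension of indices. A reference sketch for the simplest case, where every index tuple addresses a single element of data (illustrative only; indices assumed non-negative and in range):

#include <cstdint>
#include <vector>

// Reference semantics of ScatterND (reduction="none") when each indices tuple fully indexes
// one element of data. 'shape' is the data shape; index_tuples holds one coordinate tuple
// per update value.
static std::vector<float> ScatterNDElements(std::vector<float> data,
                                            const std::vector<int64_t>& shape,
                                            const std::vector<std::vector<int64_t>>& index_tuples,
                                            const std::vector<float>& updates) {
  for (size_t n = 0; n < index_tuples.size(); ++n) {
    // Convert the coordinate tuple to a flat row-major offset.
    int64_t offset = 0;
    for (size_t d = 0; d < shape.size(); ++d) {
      offset = offset * shape[d] + index_tuples[n][d];
    }
    data[offset] = updates[n];
  }
  return data;
}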
+ private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status ScatterNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = + model_builder.GetBuilder().call("scatterND", data, indices, updates, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("reduction", "none") != "none") { + LOGS(logger, VERBOSE) << "ScatterND: WebNN only supports reduction type none (default)"; + return false; + } + + return true; +} + +bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& updates = *node.InputDefs()[2]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + int32_t updates_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) || + !GetType(updates, updates_type, logger)) { + return false; + } + + if (data_type != updates_type) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 3f0d633ac888b..41c66038c2694 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -27,6 +27,8 @@ class SliceOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool 
HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } }; @@ -40,8 +42,7 @@ void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const No } } -Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; @@ -49,9 +50,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto rank = input_shape.size(); NodeAttrHelper helper(node); - emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name()); - std::vector starts(rank); - std::vector sizes(rank); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); // Copy the data from the starts/ends/axes/steps initializers. std::vector input_starts; @@ -75,8 +74,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& initializers(model_builder.GetInitializerTensors()); const auto& tensor = *initializers.at(input_name); if (!ReadIntArrayFrom1DTensor(tensor, data, logger)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Data type for starts and ends inputs is not supported in this build."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type for starts and ends inputs is not supported in this build."); } return Status::OK(); @@ -88,28 +86,55 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_ERROR( SliceOp::PrepareForComputeHelper(input_starts, input_ends, input_axes, input_steps, compute_metadata)); - std::transform(compute_metadata.starts_.cbegin(), compute_metadata.starts_.cend(), - starts.begin(), - [](int64_t i) { return SafeInt(i); }); - std::transform(compute_metadata.ends_.cbegin(), compute_metadata.ends_.cend(), compute_metadata.starts_.cbegin(), - sizes.begin(), - [](int64_t i, int64_t j) { return SafeInt(i - j); }); + // Check if reverse op is needed. + std::vector reverse_axes; + emscripten::val reverse_output = input; + for (size_t i = 0; i < rank; ++i) { + if (compute_metadata.steps_[i] < 0) { + reverse_axes.push_back(SafeInt(i)); + compute_metadata.steps_[i] = -compute_metadata.steps_[i]; + compute_metadata.starts_[i] = input_shape[i] - 1 - compute_metadata.starts_[i]; + compute_metadata.ends_[i] = input_shape[i] - 1 - compute_metadata.ends_[i]; + } + } + if (!reverse_axes.empty()) { + emscripten::val reverse_options = emscripten::val::object(); + reverse_options.set("axes", emscripten::val::array(reverse_axes)); + reverse_options.set("label", node.Name() + "_reverse"); + reverse_output = model_builder.GetBuilder().call("reverse", input, reverse_options); + } - emscripten::val options = emscripten::val::object(); - options.set("label", node.Name()); - emscripten::val output = model_builder.GetBuilder().call("slice", inputs, - emscripten::val::array(starts), - emscripten::val::array(sizes), - options); + // Check if slice op is needed. 
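Before the slice-needed check that follows, a note on the reverse rewrite above: a negative step is handled as a reverse along that axis followed by a positive-step slice, with start' = dim-1-start, end' = dim-1-end and step' = -step. A 1-D sketch showing that both formulations select the same elements (illustrative, not part of the builder):

#include <algorithm>
#include <cstdint>
#include <vector>

// ONNX Slice with a negative step: take data[start], data[start+step], ... while the index
// stays strictly above 'end'.
static std::vector<int> NegativeStepSlice(const std::vector<int>& data,
                                          int64_t start, int64_t end, int64_t step) {
  std::vector<int> out;
  for (int64_t i = start; i > end; i += step) out.push_back(data[i]);
  return out;
}

// The same selection expressed as reverse + positive-step slice, mirroring the rewrite above.
static std::vector<int> ReverseThenSlice(std::vector<int> data,
                                         int64_t start, int64_t end, int64_t step) {
  const int64_t dim = static_cast<int64_t>(data.size());
  std::reverse(data.begin(), data.end());
  std::vector<int> out;
  for (int64_t i = dim - 1 - start; i < dim - 1 - end; i += -step) out.push_back(data[i]);
  return out;
}

// Example: for data = {0,1,2,3,4,5}, start=4, end=0, step=-2 both functions return {4, 2}.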
+ bool is_slice_required = false; + for (size_t i = 0; i < rank; ++i) { + if (compute_metadata.steps_[i] != 1 || compute_metadata.starts_[i] != 0 || + compute_metadata.ends_[i] != input_shape[i]) { + is_slice_required = true; + break; + } + } + + emscripten::val output = reverse_output; + if (is_slice_required) { + std::vector starts = GetVecUint32FromVecInt64(compute_metadata.starts_); + std::vector steps = GetVecUint32FromVecInt64(compute_metadata.steps_); + std::vector sizes(rank); + std::transform(compute_metadata.ends_.cbegin(), compute_metadata.ends_.cend(), compute_metadata.starts_.cbegin(), + sizes.begin(), [](int64_t i, int64_t j) { return SafeInt(i - j); }); + + emscripten::val options = emscripten::val::object(); + options.set("strides", emscripten::val::array(steps)); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("slice", reverse_output, emscripten::val::array(starts), + emscripten::val::array(sizes), options); + } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } -bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { +bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -129,39 +154,37 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, // Optional tensors (axes, steps) can be indicated by an empty name, just ignore it. const std::string input_name = GetTensorName(input_defs, i); if (!input_name.empty() && !Contains(initializers, input_name)) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of " << op_type - << " [" << name << "] must be known as initializer"; + LOGS(logger, VERBOSE) << "Input [" << input_name << "] of " << op_type << " [" << name + << "] must be known as initializer"; return false; } } - if (input_defs.size() == 5) { // Check steps. - const auto& steps_tensor = *initializers.at(input_defs[4]->Name()); - std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(steps_tensor, unpacked_tensor); - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Error while unpacking steps_tensor: " << status.ErrorMessage(); + return true; +} + +bool SliceOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& input = *input_defs[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + // If there is step < 0, check data type support of reverse. + if (input_defs.size() > 4 && input_defs[4]->Exists()) { + std::vector steps; + if (!ReadIntArrayFrom1DTensor(*initializers.at(input_defs[4]->Name()), steps, logger)) return false; - } - const auto data_type = steps_tensor.data_type(); - // WebNN doesn't support steps other than 1. 
- if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { - if (!std::all_of(reinterpret_cast(unpacked_tensor.data()), - reinterpret_cast(unpacked_tensor.data() + unpacked_tensor.size()), - [](int64_t i) { return i == 1; })) { - return false; - } - } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) { - if (!std::all_of(reinterpret_cast(unpacked_tensor.data()), - reinterpret_cast(unpacked_tensor.data()) + - unpacked_tensor.size() / sizeof(int32_t), - [](int32_t i) { return i == 1; })) { + if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { return false; } } } - return true; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); } void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index 4c59b694d690a..db10720f72762 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -28,6 +28,8 @@ class SplitOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -73,8 +75,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Check that the splits evenly divide. if (split_count > 0 && splits.empty() && input_shape[axis] % split_count != 0) { // Divide inputs into variable size outputs: - splits.insert(splits.end(), split_count - 1, gsl::narrow(input_shape[axis]) / split_count); - splits.insert(splits.end(), gsl::narrow(input_shape[axis]) % split_count); + splits.insert(splits.end(), split_count - 1, narrow(input_shape[axis]) / split_count); + splits.insert(splits.end(), narrow(input_shape[axis]) % split_count); } if (splits.empty()) { @@ -163,6 +165,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool SplitOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output_defs = node.OutputDefs(); + const auto& op_type = node.OpType(); + int32_t output_type = 0; + + if (GetType(*output_defs[0], output_type, logger)) { + // Chromium has changed the output name of split from 'output' to 'outputs', + // to avoid breaking the existing API, we need to check both names. + std::string wnn_output_name = wnn_limits["split"]["output"].isUndefined() ? 
"outputs" : "output"; + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, wnn_output_name, "outputs", logger); + } + + return false; +} + void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 4b6cf312074ba..c7b3129c0c85b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,8 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -46,8 +46,8 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool TernaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // condition data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 8e64e98445f03..91af452c64efd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -51,6 +51,8 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const output = model_builder.GetBuilder().call("neg", input, options); } else if (op_type == "Reciprocal") { output = model_builder.GetBuilder().call("reciprocal", input, options); + } else if (op_type == "Sign") { + output = model_builder.GetBuilder().call("sign", input, options); } else if (op_type == "Sin") { output = model_builder.GetBuilder().call("sin", input, options); } else if (op_type == "Sqrt") { @@ -82,6 +84,7 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Log", "Neg", "Reciprocal", + "Sign", "Sin", "Sqrt", "Tan", diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index fcfdb146bff34..231b65a4d1894 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -42,6 +42,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap(tensor.buffer))}; @@ -93,6 +95,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap(tensor.buffer))}; @@ -210,6 +214,8 @@ void Model::AllocateInputOutputBuffers() { const auto data_type = input_info.data_type; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case 
ONNX_NAMESPACE::TensorProto_DataType_INT4: + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements)); break; @@ -245,6 +251,8 @@ void Model::AllocateInputOutputBuffers() { const auto data_type = output_info.data_type; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements)); break; diff --git a/onnxruntime/core/providers/webnn/builders/model.h b/onnxruntime/core/providers/webnn/builders/model.h index c554dcb6f6877..b8ab6677636db 100644 --- a/onnxruntime/core/providers/webnn/builders/model.h +++ b/onnxruntime/core/providers/webnn/builders/model.h @@ -6,7 +6,7 @@ #include "core/common/inlined_containers.h" #include "core/common/status.h" -#include "core/platform/ort_mutex.h" +#include #include #include @@ -35,7 +35,7 @@ class Model { const InlinedHashMap& outputs); // Mutex for exclusive lock to this model object. - OrtMutex& GetMutex() { return mutex_; } + std::mutex& GetMutex() { return mutex_; } // Input and output names in the onnx model's order. const std::vector& GetInputs() const { return inputs_; } @@ -77,7 +77,7 @@ class Model { InlinedHashMap input_map_; InlinedHashMap output_map_; - OrtMutex mutex_; + std::mutex mutex_; bool use_dispatch_; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 044baa738e8c4..e8f116d390199 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -88,11 +88,15 @@ Status ModelBuilder::RegisterInitializers() { for (const auto& pair : GetInitializerTensors()) { const auto& tensor = *pair.second; const auto& name = tensor.name(); - // Optional tensors can be indicated by an empty name, just ignore it. - if (name.empty() || Contains(skipped_initializers_, name)) + const auto& shape = tensor.dims(); + + // Ignore the following tensors: + // 1. Empty tensors: optional tensors can be indicated by an empty name. + // 2. Tensors in skipped_initializers_: These are tensors that are not used as WebNN Constants. + // Note: Scalar tensors are excluded because ONNX Runtime will optimize same scalar initializers into one. + if (name.empty() || (Contains(skipped_initializers_, name) && !shape.empty())) continue; - const auto& shape = tensor.dims(); std::vector dims; // When the shape is empty, it is scalar initializer that dims = {}; std::transform(shape.cbegin(), shape.cend(), @@ -112,56 +116,81 @@ Status ModelBuilder::RegisterInitializers() { auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); std::byte* tensor_ptr = nullptr; - if (tensor.has_raw_data()) { - tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); + + if (utils::HasExternalData(tensor)) { + // Create WebNN Constant from external data. 
+ std::basic_string external_file_path; + onnxruntime::FileOffsetType data_offset; + SafeInt tensor_byte_size; + ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo( + tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size)); + + auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant"); + operand = jsepRegisterMLConstant(emscripten::val(external_file_path), + static_cast(data_offset), + static_cast(tensor_byte_size), + wnn_builder_, + desc); } else { - // Store temporary unpacked_tensor. - unpacked_tensors_.push_back({}); - std::vector& unpacked_tensor = unpacked_tensors_.back(); - ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); - tensor_ptr = reinterpret_cast(unpacked_tensor.data()); + if (tensor.has_raw_data()) { + tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); + } else { + // Store temporary unpacked_tensor. + unpacked_tensors_.push_back({}); + std::vector& unpacked_tensor = unpacked_tensors_.back(); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); + tensor_ptr = reinterpret_cast(unpacked_tensor.data()); + } + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4 || + data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + // For WebNN int4 and uint4 tensors are stored in Uint8Array, + // so we need to adjust the number of elements. + num_elements = (static_cast(num_elements) + 1) / 2; + } + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + default: + break; + } + + // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached + // buffers in JS side. Simply create a copy to fix it. 
+ operand = wnn_builder_.call("constant", desc, view.call("slice")); } - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - default: - break; - } - - // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached - // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); } else { // TODO: support other type. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -197,7 +226,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (!shape.empty()) { dims.reserve(shape.size()); for (const auto& dim : shape) { - // dim_param free dimensions should have already been excluded by IsInputSupported(). + // dim_param free dimensions should have already been excluded by IsTensorShapeSupported(). assert(dim.has_dim_value()); dims.push_back(SafeInt(dim.dim_value())); } @@ -355,60 +384,6 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op wnn_operands_.insert(std::make_pair(name, operand)); } -// Get the zero scalar constant. -// Workaround for builer.constant(value, type) method since it has not been implemented now. -// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-value-type -// BTW, the spec is discussing if the builer.constant(value, type) should be dropped at -// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision. -const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) { - std::string name = "webnn_zero_constant_" + std::to_string(data_type); - // If the operand does not exist, create it. 
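The zero-constant helper being removed here is superseded by the templated CreateOrGetConstant declared in model_builder.h later in this patch, which caches constants under a key derived from data type, value, and shape so repeated requests reuse the same MLOperand. A sketch of that key scheme (function name hypothetical; the key format follows the new code):

#include <cstdint>
#include <string>
#include <vector>

// Builds a cache key in the spirit of CreateOrGetConstant: data type and value always,
// plus one "_<dim>" suffix per dimension when a shape is given.
static std::string ConstantCacheKey(int32_t data_type, int64_t value,
                                    const std::vector<uint32_t>& shape) {
  std::string key = "webnn_constant_" + std::to_string(data_type) + "_" + std::to_string(value);
  for (uint32_t dim : shape) {
    key += "_" + std::to_string(dim);
  }
  return key;
}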
- if (wnn_operands_.find(name) == wnn_operands_.end()) { - emscripten::val desc = emscripten::val::object(); - emscripten::val dims = emscripten::val::array(); - desc.set("dimensions", dims); - desc.set("shape", dims); - emscripten::val zero_buffer = emscripten::val::undefined(); - if (!SetWebnnDataType(desc, data_type)) { - ORT_THROW("Unsupported data type: " + std::to_string(data_type)); - } - - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - zero_buffer = emscripten::val::global("Uint8Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - zero_buffer = emscripten::val::global("Int8Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - zero_buffer = emscripten::val::global("Uint16Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - zero_buffer = emscripten::val::global("Float32Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - zero_buffer = emscripten::val::global("Int32Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - zero_buffer = emscripten::val::global("BigInt64Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - zero_buffer = emscripten::val::global("Uint32Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - zero_buffer = emscripten::val::global("BigUint64Array").new_(1); - break; - default: - break; - } - - emscripten::val zero_constant = wnn_builder_.call("constant", desc, zero_buffer); - wnn_operands_.insert(std::make_pair(name, zero_constant)); - } - return wnn_operands_.at(name); -} - void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { skipped_initializers_.insert(tensor_name); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 13937933a0a9c..0fc2fa20670c7 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -11,6 +11,7 @@ #include "core/framework/execution_provider.h" #include "core/providers/webnn/builders/helper.h" +#include #include #include @@ -38,7 +39,11 @@ class ModelBuilder { const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } void AddOperand(const std::string& name, const emscripten::val& operand); - const emscripten::val& GetZeroConstant(const int32_t& data_type); + + template + const emscripten::val& CreateOrGetConstant(const int32_t& data_type, T value, + const std::vector& shape = {}); + // Use the buffers to persist WebNN allocated data like transposed weight. // It ensures the validity during inference session. std::vector> mem_persist_buffers_; @@ -98,5 +103,120 @@ class ModelBuilder { static const IOpBuilder* GetOpBuilder(const Node& node); }; +// Create or retrieve one of the following: +// - A WebNN constant MLOperand filled with the specified value, data type, and shape. +// - A WebNN scalar constant MLOperand with the specified value and data type. +// For scalar constant, it is workaround for builer.constant(type, value) method since +// it has not been implemented now. 
+// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value +// +// This function enforces a mapping between the data_type and the value types: +// - TensorProto_DataType_INT4 <-> int8_t +// - TensorProto_DataType_UINT4 <-> int8_t +// - TensorProto_DataType_BOOL <-> bool +// - TensorProto_DataType_UINT8 <-> uint8_t +// - TensorProto_DataType_INT8 <-> int8_t +// - TensorProto_DataType_FLOAT16 <-> float +// - TensorProto_DataType_FLOAT <-> float +// - TensorProto_DataType_INT32 <-> int32_t +// - TensorProto_DataType_INT64 <-> int64_t +// - TensorProto_DataType_UINT32 <-> uint32_t +// - TensorProto_DataType_UINT64 <-> uint64_t +template +const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_type, T value, + const std::vector& shape) { + std::string name = "webnn_constant_" + std::to_string(data_type) + "_" + std::to_string(value); + emscripten::val dims = emscripten::val::array(); + if (!shape.empty()) { + dims = emscripten::val::array(shape); + std::ostringstream name_stream; + name_stream << name; + for (const auto& dim : shape) { + name_stream << "_" << dim; + } + name = name_stream.str(); + } + + // If the operand does not exist, create it. + if (wnn_operands_.find(name) == wnn_operands_.end()) { + emscripten::val desc = emscripten::val::object(); + desc.set("shape", dims); + desc.set("dimensions", dims); + emscripten::val buffer = emscripten::val::undefined(); + if (!SetWebnnDataType(desc, data_type)) { + ORT_THROW("Unsupported data type: " + std::to_string(data_type)); + } + auto num_elements = Product(shape); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + // For WebNN int4 and uint4 tensors are stored in Uint8Array, + // so we need to adjust the number of elements. 
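WebNN stores int4/uint4 tensors packed two values per byte in a Uint8Array, which is why the element count is halved (rounding up) on the next line. A hedged sketch of such packing for a buffer of 4-bit values (illustrative; the real conversion lives in the shared helpers, and the low-nibble-first order here is an assumption):

#include <cstdint>
#include <vector>

// Packs 4-bit values (one per int8_t, low nibble significant) into bytes, low nibble first.
// The output holds (n + 1) / 2 bytes for n input values, matching the size adjustment below.
static std::vector<uint8_t> PackNibbles(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 1) / 2, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    const uint8_t nibble = static_cast<uint8_t>(values[i]) & 0x0F;
    if (i % 2 == 0) {
      packed[i / 2] = nibble;        // low nibble of the byte
    } else {
      packed[i / 2] |= nibble << 4;  // high nibble of the byte
    }
  }
  return packed;
}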
+ num_elements = (num_elements + 1) / 2; + buffer = emscripten::val::global("Uint8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(PackInt8ToUint8AsNibble(value, data_type))); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + buffer = emscripten::val::global("Uint8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + buffer = emscripten::val::global("Int8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + buffer = emscripten::val::global("Uint16Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value))); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + buffer = emscripten::val::global("Float32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + buffer = emscripten::val::global("Int32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + buffer = emscripten::val::global("Uint32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + buffer = emscripten::val::global("BigInt64Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val::global("BigInt")(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + buffer = emscripten::val::global("BigUint64Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val::global("BigInt")(value)); + } + break; + default: + break; + } + + const emscripten::val constant = wnn_builder_.call("constant", desc, buffer); + wnn_operands_.insert(std::make_pair(name, constant)); + } + + return wnn_operands_.at(name); +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 8baa4790247ec..6d1c572128b93 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -26,6 +26,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Sign", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); CreateUnaryOpBuilder("Tan", op_registrations); @@ -80,6 +81,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateConcatOpBuilder("Concat", op_registrations); } + { // CumSum + CreateCumSumOpBuilder("CumSum", op_registrations); + } + { // Dropout CreateDropoutOpBuilder("Dropout", op_registrations); } @@ -90,6 +95,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations); } + { // Einsum + CreateEinsumOpBuilder("Einsum", op_registrations); + } + { // Expand CreateExpandOpBuilder("Expand", op_registrations); } @@ -98,6 +107,14 @@ static OpBuilderRegistrations 
CreateOpBuilderRegistrations() { CreateGatherOpBuilder("Gather", op_registrations); } + { // GatherElements + CreateGatherElementsOpBuilder("GatherElements", op_registrations); + } + + { // GatherND + CreateGatherNDOpBuilder("GatherND", op_registrations); + } + { // Flatten CreateFlattenOpBuilder("Flatten", op_registrations); } @@ -113,12 +130,19 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { } { // Logical + CreateLogicalOpBuilder("And", op_registrations); CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); CreateLogicalOpBuilder("LessOrEqual", op_registrations); CreateLogicalOpBuilder("Not", op_registrations); + CreateLogicalOpBuilder("Or", op_registrations); + CreateLogicalOpBuilder("Xor", op_registrations); + } + + { // LRN + CreateLRNOpBuilder("LRN", op_registrations); } { // LSTM @@ -134,6 +158,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateNormalizationOpBuilder("BatchNormalization", op_registrations); CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); CreateNormalizationOpBuilder("LayerNormalization", op_registrations); + CreateNormalizationOpBuilder("SimplifiedLayerNormalization", op_registrations); } { // Pad @@ -170,6 +195,14 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateResizeOpBuilder("Resize", op_registrations); } + { // ScatterElements + CreateScatterElementsOpBuilder("ScatterElements", op_registrations); + } + + { // ScatterND + CreateScatterNDOpBuilder("ScatterND", op_registrations); + } + { // Shape CreateShapeOpBuilder("Shape", op_registrations); } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 990be04d42107..22bd6cd0cfa9f 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -26,14 +26,19 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& 
op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -43,6 +48,8 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 2258d1ac1cd8f..1a337e185b497 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -291,7 +291,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector lock(model->GetMutex()); + std::unique_lock lock(model->GetMutex()); InlinedHashMap outputs; outputs.reserve(model_outputs.size()); for (size_t i = 0; i < model_outputs.size(); i++) { diff --git a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc index e2d71cda68ec4..a0968dbc380cd 100644 --- a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc +++ b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc @@ -90,6 +90,17 @@ const NodeUnit* ClipReluChecker(const NodeUnit& node_unit, } // namespace bool NodeSupportChecker::IsNodeSupported(const NodeUnit& nodeunit) { +#ifndef XNNPACK_FP16_SUPPORTED + // check whether the hardware support XNNPack FP16 + // Note. In CI, ios pipeline on ADO doesn't support XNNPack FP16. Because ADO mac pool is still x64. 
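The guard added below rejects fp16 inputs up front when the build does not define XNNPACK_FP16_SUPPORTED; the Gemm and MatMul changes later in this patch then choose between the f32 and f16 XNNPACK entry points based on the input element type. A condensed sketch of that selection pattern (names are illustrative, not the real kernel code; the numeric codes are the ONNX TensorProto values for FLOAT and FLOAT16):

#include <cstdint>

enum class ComputeType { kInvalid, kFp32, kFp16 };

constexpr int32_t kFloat = 1;     // TensorProto_DataType_FLOAT
constexpr int32_t kFloat16 = 10;  // TensorProto_DataType_FLOAT16

// Maps the input element type to an XNNPACK compute path; fp16 is only accepted when the
// build advertises fp16 support.
static ComputeType SelectComputeType(int32_t elem_type) {
  if (elem_type == kFloat) return ComputeType::kFp32;
#ifdef XNNPACK_FP16_SUPPORTED
  if (elem_type == kFloat16) return ComputeType::kFp16;
#endif
  return ComputeType::kInvalid;
}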
+ const auto& inputs = nodeunit.Inputs(); + const auto& x_arg = inputs[0].node_arg; + const auto* x_type = x_arg.TypeAsProto(); + if (x_type == nullptr || x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + return false; + } +#endif + static std::unordered_map checkers{ {"Conv", Conv::IsOnnxNodeSupported}, {"ConvTranspose", ConvTranspose::IsOnnxNodeSupported}, diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index f9cb45ebc8abc..2adf8339b4b66 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -3,15 +3,19 @@ #include "utils.h" #include +#include #include +#include #include "core/common/common.h" #include "core/common/safeint.h" #include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" +#include "core/graph/graph.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/initializer.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "onnx/defs/attr_proto_util.h" @@ -111,6 +115,10 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { auto_pad == AutoPadType::SAME_UPPER; } +bool IsComputeTypeSupported(int32_t compute_type, const ComputeTypeSet& compute_type_set) { + return std::find(compute_type_set.begin(), compute_type_set.end(), compute_type) != compute_type_set.end(); +} + typedef std::string ONNXOpType; static const std::unordered_map qdq_to_onnx_type_map = { @@ -232,8 +240,8 @@ std::unique_ptr FuseActivation(const NodeUnit& node_un def.attributes = node_unit.GetNode().GetAttributes(); // use infinity as the default as that's what xnnpack uses if min/max are not set - float min = -INFINITY; - float max = INFINITY; + float min = -std::numeric_limits::infinity(); + float max = std::numeric_limits::infinity(); const auto& activation_type = activation.OpType(); if (activation_type == "Clip") { diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index d555ee2286b84..0a80bc0450b99 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -6,14 +6,15 @@ #include #include #include -#include #include +#include #include #include "core/framework/node_unit.h" #include "core/framework/op_kernel.h" #include "core/graph/indexed_sub_graph.h" #include "core/providers/common.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "xnnpack.h" @@ -77,6 +78,20 @@ struct XnnpackOperatorDeleter { bool IsPaddingTypeSupported(AutoPadType auto_pad); +using ComputeTypeSet = std::unordered_set; +#ifdef XNNPACK_FP16_SUPPORTED +bool IsComputeTypeSupported(int32_t compute_type, + const ComputeTypeSet& compute_type_set = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8}); +#else +bool IsComputeTypeSupported(int32_t compute_type, + const ComputeTypeSet& compute_type_set = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8}); +#endif + using XnnpackOperator = std::unique_ptr; std::unique_ptr FuseActivation(const NodeUnit& conv_unit, const NodeUnit& activation, @@ -99,5 +114,6 @@ auto xnn_u8s8_quantize(float val, float scale, T zero_point) { auto zp = static_cast(zero_point); return static_cast(lrintf(fminf(fmaxf(val / 
scale + zp, typed_min), typed_max))); } + } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index f7b736b0ff903..a3ff3b585ae45 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -2,8 +2,12 @@ // Licensed under the MIT License. #include "gemm.h" + +#include + #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { namespace xnnpack { @@ -37,7 +41,8 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra const auto* A_type = A_arg->TypeAsProto(); if (A_type == nullptr || - A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + (A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { break; } @@ -74,19 +79,26 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra supported = true; } while (false); - return supported; } Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*enable_caches*/ true) { - const auto& node{Node()}; - info.GetAttrOrDefault("alpha", &alpha_, 1.f); info.GetAttrOrDefault("beta", &beta_, 1.f); + const auto& node{Node()}; const auto& input_defs = node.InputDefs(); const auto* shapeA = input_defs[0]->Shape(); const auto* shapeB = input_defs[1]->Shape(); + + const NodeArg& X = *input_defs[0]; + auto input_dtype = X.TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + op_compute_type_ = OpComputeType::op_compute_type_fp32; + } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + op_compute_type_ = OpComputeType::op_compute_type_fp16; + } + const NodeArg* C_arg = input_defs.size() == 2 ? nullptr : input_defs[2]; C_matrix_exists_ = C_arg && C_arg->Exists(); @@ -127,32 +139,49 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, // flags - 1 - for no transpose - 0 for transpose uint32_t flags = trans_B_ == CblasTrans ? 0 : XNN_FLAG_TRANSPOSE_WEIGHTS; - - float output_min = clip_min_max_ ? clip_min_max_->first : -INFINITY; - float output_max = clip_min_max_ ? clip_min_max_->second : INFINITY; - - const float* bias_Data = nullptr; - - if (C_matrix_exists_) { - bias_Data = tensor.Data(); - } - + auto code_cache = GetCodeCache(); + auto weights_cache = GetWeightsCache(); xnn_status status = xnn_status::xnn_status_uninitialized; struct xnn_operator* p = nullptr; - status = xnn_create_fully_connected_nc_f32( - trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_channels, - trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_channels, - trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_stride, - trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride, - B_->Data(), // const float* kernel, - bias_Data, // const float* bias, - output_min, output_max, - flags, - GetCodeCache(), GetWeightsCache(), - &p); + float foutput_min = clip_min_max_ ? clip_min_max_->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max_ ? 
clip_min_max_->second : std::numeric_limits::infinity(); + if (op_compute_type_ == OpComputeType::op_compute_type_fp32) { + const float* bias_data = nullptr; + if (C_matrix_exists_) { + bias_data = tensor.Data(); + } + status = xnn_create_fully_connected_nc_f32( + trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_channels, + trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_channels, + trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_stride, + trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride, + B_->Data(), // const float* kernel, + bias_data, // const float* bias, + foutput_min, foutput_max, + flags, + code_cache, weights_cache, + &p); + } else if (op_compute_type_ == OpComputeType::op_compute_type_fp16) { + const MLFloat16* bias_data = nullptr; + if (C_matrix_exists_) { + bias_data = tensor.Data(); + } + status = xnn_create_fully_connected_nc_f16( + trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_channels, + trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_channels, + trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_stride, + trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride, + B_->Data(), // const MLFloat16* kernel, + bias_data, // const float* bias, + foutput_min, foutput_max, + flags, + code_cache, weights_cache, + &p); + } if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_create_fully_connected_nc_f32 returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_create_fully_connected_nc_", + OpTypeToString(op_compute_type_), " returned ", status); } op0_.reset(p); @@ -169,19 +198,30 @@ Status Gemm::Compute(OpKernelContext* context) const { return Status::OK(); } - xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), - // Number of rows to multiply - trans_A_ == CblasNoTrans ? M_ : K_, - threadpool); + auto reshape_func = xnn_reshape_fully_connected_nc_f32; + if (op_compute_type_ == OpComputeType::op_compute_type_fp16) { + reshape_func = xnn_reshape_fully_connected_nc_f16; + } + xnn_status status = reshape_func(op0_.get(), + // Number of rows to multiply + trans_A_ == CblasNoTrans ? 
M_ : K_, + threadpool); if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_", + OpTypeToString(op_compute_type_), " returned ", status); } - status = xnn_setup_fully_connected_nc_f32(op0_.get(), A->Data(), Y->MutableData()); + status = xnn_status_invalid_state; + if (op_compute_type_ == op_compute_type_fp32) { + status = xnn_setup_fully_connected_nc_f32(op0_.get(), A->Data(), Y->MutableData()); + } else if (op_compute_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_fully_connected_nc_f16(op0_.get(), A->Data(), Y->MutableData()); + } if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_", + OpTypeToString(op_compute_type_), " returned ", status); } status = xnn_run_operator(op0_.get(), nullptr); @@ -193,19 +233,23 @@ Status Gemm::Compute(OpKernelContext* context) const { } ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 8, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Gemm); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 9, 10, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Gemm); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 11, 12, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Gemm); ONNX_OPERATOR_KERNEL_EX(Gemm, kOnnxDomain, 13, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Gemm); } // namespace xnnpack diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.h b/onnxruntime/core/providers/xnnpack/math/gemm.h index 6d11a8531c20f..954aab0698b9c 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.h +++ b/onnxruntime/core/providers/xnnpack/math/gemm.h @@ -41,6 +41,8 @@ class Gemm : protected GemmBase, public XnnpackKernel { float alpha_; float beta_; + + OpComputeType op_compute_type_ = OpComputeType::op_compute_type_invalid; }; } // namespace xnnpack diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index e90aa11c9d087..f574238195ffd 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -2,7 +2,9 @@ // Licensed under the MIT License. #include "matmul.h" +#include #include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/xnnpack/xnnpack_init.h" // Todo - // 1. 
Integrate activation layers - Cliping & Relu @@ -34,7 +36,8 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g const auto* A_shape = A_arg.Shape(); const auto* B_shape = B_arg.Shape(); - if (A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + if (A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { break; } @@ -62,7 +65,18 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g return supported; } -MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ true) {} +MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ true) { + const auto& node{Node()}; + const auto& input_defs = node.InputDefs(); + const NodeArg& X = *input_defs[0]; + auto input_dtype = X.TypeAsProto()->tensor_type().elem_type(); + op_type_str_ = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto())); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + op_type_ = OpComputeType::op_compute_type_fp32; + } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + op_type_ = OpComputeType::op_compute_type_fp16; + } +} Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -78,8 +92,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, is_packed = true; uint32_t flags = XNN_FLAG_TRANSPOSE_WEIGHTS; - float output_min = -INFINITY; - float output_max = INFINITY; + xnn_status status = xnn_status::xnn_status_uninitialized; struct xnn_operator* p = nullptr; @@ -88,27 +101,49 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, if (b_shape_.NumDimensions() == 1) { shape_broadcast.push_back(1); } - status = xnn_create_fully_connected_nc_f32( - shape_broadcast[0], // size_t input_channels, - shape_broadcast[1], // size_t output_channels, - shape_broadcast[0], // size_t input_stride, - shape_broadcast[1], // size_t output_stride, - tensor.Data(), // const float* kernel, - nullptr, // const float* bias, - output_min, - output_max, - flags, + #ifdef XNN_CACHE_ENABLE - GetCodeCache(), - GetWeightsCache(), + xnn_code_cache_t code_cache = GetCodeCache(); + xnn_weights_cache_t weight_cache = GetWeightsCache(); #else - nullptr, - nullptr, + xnn_code_cache_t code_cache = nullptr; + xnn_weights_cache_t weight_cache = nullptr; #endif - &p); + + float foutput_min = -std::numeric_limits::infinity(); + float foutput_max = std::numeric_limits::infinity(); + if (op_type_ == OpComputeType::op_compute_type_fp32) { + status = xnn_create_fully_connected_nc_f32( + shape_broadcast[0], // size_t input_channels, + shape_broadcast[1], // size_t output_channels, + shape_broadcast[0], // size_t input_stride, + shape_broadcast[1], // size_t output_stride, + tensor.Data(), // const float* kernel, + nullptr, // const float* bias, + foutput_min, + foutput_max, + flags, + code_cache, + weight_cache, + &p); + } else if (op_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_create_fully_connected_nc_f16( + shape_broadcast[0], // size_t input_channels, + shape_broadcast[1], // size_t output_channels, + shape_broadcast[0], // size_t input_stride, + shape_broadcast[1], // size_t output_stride, + tensor.Data(), // const MLFloat16* kernel, + nullptr, // const MLFloat16* bias, + foutput_min, + foutput_max, + flags, + code_cache, + weight_cache, + 
&p); + } if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_create_fully_connected_nc_f32 returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_create_fully_connected_nc_", op_type_str_, " returned ", status); } op0_.reset(p); @@ -118,24 +153,35 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, Status MatMul::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); - pthreadpool_t threadpool = GetThreadPool(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_)); Tensor* y = ctx->Output(0, helper.OutputShape()); - if (y->Shape().Size() == 0) return Status::OK(); - auto* y_data = y->MutableData(); + xnn_status status = xnn_status_success; + + pthreadpool_t threadpool = GetThreadPool(); + if (op_type_ == OpComputeType::op_compute_type_fp32) { + status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool); + } else if (op_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_reshape_fully_connected_nc_f16(op0_.get(), a->Shape()[0], threadpool); + } - xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool); if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_", op_type_str_, " returned ", status); + } + + if (op_type_ == OpComputeType::op_compute_type_fp32) { + auto* y_data = y->MutableData(); + status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data(), y_data); + } else if (op_type_ == OpComputeType::op_compute_type_fp16) { + auto* y_data = y->MutableData(); + status = xnn_setup_fully_connected_nc_f16(op0_.get(), a->Data(), y_data); } - status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data(), y_data); if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_", op_type_str_, " returned ", status); } status = xnn_run_operator(op0_.get(), nullptr); @@ -146,15 +192,18 @@ Status MatMul::Compute(OpKernelContext* ctx) const { } ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 8, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), MatMul); ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 9, 12, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), MatMul); ONNX_OPERATOR_KERNEL_EX(MatMul, kOnnxDomain, 13, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), MatMul); } // namespace xnnpack diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.h b/onnxruntime/core/providers/xnnpack/math/matmul.h index b76e42c4d3729..188cc73189af5 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.h +++ b/onnxruntime/core/providers/xnnpack/math/matmul.h @@ -31,6 +31,9 @@ class MatMul : public XnnpackKernel { BufferUniquePtr packed_b_; AllocatorPtr myAlloc; + OpComputeType op_type_ = 
OpComputeType::op_compute_type_invalid; + std::string op_type_str_ = ""; + XnnpackOperator op0_ = nullptr; }; diff --git a/onnxruntime/core/providers/xnnpack/math/softmax.cc b/onnxruntime/core/providers/xnnpack/math/softmax.cc index 43e3ac193de5d..15e260889b055 100644 --- a/onnxruntime/core/providers/xnnpack/math/softmax.cc +++ b/onnxruntime/core/providers/xnnpack/math/softmax.cc @@ -6,8 +6,9 @@ #include #include "core/framework/op_kernel.h" -#include "core/providers/cpu/math/softmax_shared.h" #include "core/optimizer/initializer.h" +#include "core/providers/cpu/math/softmax_shared.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { namespace xnnpack { @@ -70,6 +71,7 @@ bool Softmax::IsOnnxNodeSupported(const NodeUnit& node_unit, const auto* x_type = x_arg.TypeAsProto(); if (x_type == nullptr || (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8)) { break; } @@ -120,14 +122,16 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} { ORT_ENFORCE(GetType(*input_defs[0], x_dtype)); if (x_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { op_type_ = OpComputeType::op_compute_type_fp32; + } else if (x_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + op_type_ = OpComputeType::op_compute_type_fp16; } else if (x_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { op_type_ = OpComputeType::op_compute_type_qu8; } else { auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*input_defs[0]->TypeAsProto())); - ORT_THROW("unsupported Conv in softmax, we have FLOAT|UINT8, but got ", stype); + ORT_THROW("unsupported compute type in softmax, we have FLOAT|FLOAT16|UINT8, but got ", stype); } - if (op_type_ == OpComputeType::op_compute_type_fp32) { + if (op_type_ == OpComputeType::op_compute_type_fp32 || op_type_ == OpComputeType::op_compute_type_fp16) { opset_ = node.SinceVersion(); } else { // Qlinearsoftmax's opset keep 1, we have to parse it by "opset" @@ -176,6 +180,10 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} { xstatus = xnn_create_softmax_nc_f32( 0, // flags, &p); + } else if (op_type_ == OpComputeType::op_compute_type_fp16) { + xstatus = xnn_create_softmax_nc_f16( + 0, // flags, + &p); } ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_softmax_nc_", @@ -200,8 +208,13 @@ Status Softmax::Compute(OpKernelContext* ctx) const { // const size_t D = X_shape.SizeFromDimension(axis_); // the step D is 1 xnn_status status = xnn_status_invalid_state; - auto reshape_fn = op_type_ == OpComputeType::op_compute_type_qu8 ? 
xnn_reshape_softmax_nc_qu8 - : xnn_reshape_softmax_nc_f32; + auto reshape_fn = xnn_reshape_softmax_nc_f32; + if (op_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_softmax_nc_f16; + } else if (op_type_ == OpComputeType::op_compute_type_qu8) { + reshape_fn = xnn_reshape_softmax_nc_qu8; + } + status = reshape_fn(op0_.get(), channel_dim_, channel_dim_, channel_dim_, N, threadpool); if (status != xnn_status_success) { @@ -211,8 +224,10 @@ Status Softmax::Compute(OpKernelContext* ctx) const { if (op_type_ == OpComputeType::op_compute_type_qu8) { status = xnn_setup_softmax_nc_qu8(op0_.get(), X->Data(), Y->MutableData()); - } else { + } else if (op_type_ == op_compute_type_fp32) { status = xnn_setup_softmax_nc_f32(op0_.get(), X->Data(), Y->MutableData()); + } else if (op_type_ == op_compute_type_fp16) { + status = xnn_setup_softmax_nc_f16(op0_.get(), X->Data(), Y->MutableData()); } if (status != xnn_status_success) { @@ -229,15 +244,18 @@ Status Softmax::Compute(OpKernelContext* ctx) const { } ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 1, 10, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Softmax); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 11, 12, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Softmax); ONNX_OPERATOR_KERNEL_EX(Softmax, kOnnxDomain, 13, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Softmax); ONNX_OPERATOR_KERNEL_EX(QLinearSoftmax, kDynamicDomainByCreate, 1, kXnnpackExecutionProvider, diff --git a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc index b31b5a94899bf..1fc941d9f52f6 100644 --- a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc @@ -33,8 +33,8 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, if (pool_attrs.auto_pad == AutoPadType::SAME_UPPER) { flags |= XNN_FLAG_TENSORFLOW_SAME_PADDING; } - float foutput_min = clip_min_max ? clip_min_max->first : -INFINITY; - float foutput_max = clip_min_max ? clip_min_max->second : INFINITY; + float foutput_min = clip_min_max ? clip_min_max->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max ? 
clip_min_max->second : std::numeric_limits::infinity(); xnn_status status = xnn_status_unsupported_parameter; if (avgpool_type == OpComputeType::op_compute_type_fp32) { status = xnn_create_average_pooling2d_nhwc_f32(input_padding_top, input_padding_right, @@ -42,6 +42,12 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, pooling_height, pooling_width, stride_height, stride_width, foutput_min, foutput_max, flags, &p); + } else if (avgpool_type == OpComputeType::op_compute_type_fp16) { + status = xnn_create_average_pooling2d_nhwc_f16(input_padding_top, input_padding_right, + input_padding_bottom, input_padding_left, + pooling_height, pooling_width, + stride_height, stride_width, + foutput_min, foutput_max, flags, &p); } else if (avgpool_type == OpComputeType::op_compute_type_qu8) { const float output_scale = quant_param[1].first[0]; const uint8_t output_zero_point = quant_param[1].second; @@ -89,6 +95,11 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit, // share the common checks here for fp32 and quant-op const auto& inputs = node_unit.Inputs(); // use do {} while(false) so it's easier to set a breakpoint on the return + static const ComputeTypeSet compute_type_set = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + }; do { if (node_unit.SinceVersion() < 7) { break; @@ -105,8 +116,7 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit, // we only support float and u8 currently const auto* x_type = x_arg.TypeAsProto(); if (x_type == nullptr || - (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8)) { + !IsComputeTypeSupported(x_type->tensor_type().elem_type(), compute_type_set)) { break; } @@ -197,13 +207,12 @@ AveragePool::AveragePool(const OpKernelInfo& info) const auto& input_dtype = X_arg.TypeAsProto()->tensor_type().elem_type(); if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { avgpool_type_ = OpComputeType::op_compute_type_fp32; + } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + avgpool_type_ = OpComputeType::op_compute_type_fp16; } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { // the order of input tensor, x,x_scale, x_zp, y_scale, y_zp quant_param = ParseQuantParamForOp(info, input_dtype, 1); avgpool_type_ = OpComputeType::op_compute_type_qu8; - } else { - auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto())); - ORT_THROW("unsupported AveragePool in XnnpackEP, we have FLOAT|UINT8, but got ", stype); } struct xnn_operator* p; auto ret = CreateXnnpackKernel(pool_attrs_, clip_min_max_, p, @@ -241,9 +250,12 @@ Status AveragePool::Compute(OpKernelContext* context) const { std::unique_ptr workspace(nullptr, deallocator); - auto reshape_fn = (avgpool_type_ == OpComputeType::op_compute_type_fp32) - ? 
xnn_reshape_average_pooling2d_nhwc_f32 - : xnn_reshape_average_pooling2d_nhwc_qu8; + auto reshape_fn = xnn_reshape_average_pooling2d_nhwc_f32; + if (avgpool_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_average_pooling2d_nhwc_f16; + } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { + reshape_fn = xnn_reshape_average_pooling2d_nhwc_qu8; + } auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, &workspace_size, &workspace_alignment, @@ -260,7 +272,9 @@ Status AveragePool::Compute(OpKernelContext* context) const { if (avgpool_type_ == OpComputeType::op_compute_type_fp32) { status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(), X.Data(), Y.MutableData()); - + } else if (avgpool_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_average_pooling2d_nhwc_f16(op0_.get(), workspace.get(), + X.Data(), Y.MutableData()); } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), workspace.get(), X.Data(), Y.MutableData()); @@ -282,25 +296,29 @@ Status AveragePool::Compute(OpKernelContext* context) const { ONNX_OPERATOR_VERSIONED_KERNEL_EX( AveragePool, kMSInternalNHWCDomain, 7, 9, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), AveragePool); ONNX_OPERATOR_VERSIONED_KERNEL_EX( AveragePool, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), AveragePool); ONNX_OPERATOR_VERSIONED_KERNEL_EX( AveragePool, kMSInternalNHWCDomain, 11, 18, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), AveragePool); ONNX_OPERATOR_KERNEL_EX( AveragePool, kMSInternalNHWCDomain, 19, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), AveragePool); ONNX_OPERATOR_KERNEL_EX( diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index b815cc1570c96..4e6b308e28ae5 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -10,8 +10,8 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" -#include "core/providers/xnnpack/xnnpack_init.h" #include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { namespace xnnpack { @@ -22,8 +22,10 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) || - (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W + const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || + conv_type_ == OpComputeType::op_compute_type_fp16); + if ((conv_type_is_float && input_idx == 1) || + 
(!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); @@ -56,7 +58,6 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, // we can create the kernel now ORT_RETURN_IF_ERROR(CreateKernel()); } - return Status::OK(); } @@ -102,6 +103,8 @@ Status Conv::Compute(OpKernelContext* context) const { reshape_fn = xnn_reshape_convolution2d_nhwc_qu8; } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) { reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w; + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_convolution2d_nhwc_f16; } auto status = reshape_fn(op0_.get(), N, H, W, @@ -112,12 +115,14 @@ Status Conv::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_), "returned ", status); } - workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); if (conv_type_ == OpComputeType::op_compute_type_fp32) { status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qs8) { status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); @@ -143,11 +148,17 @@ Status Conv::Compute(OpKernelContext* context) const { } ONNX_OPERATOR_VERSIONED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", { + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + }), Conv); ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", { + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + }), Conv); ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 2aafc9be7ffd0..458e6000c8d70 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -54,8 +54,8 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, xnn_status status = xnn_status::xnn_status_uninitialized; p = nullptr; - float foutput_min = clip_min_max ? clip_min_max->first : -INFINITY; - float foutput_max = clip_min_max ? clip_min_max->second : INFINITY; + float foutput_min = clip_min_max ? clip_min_max->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max ? clip_min_max->second : std::numeric_limits::infinity(); // with the following IC and OC number, we can cover depthwise and regular conv at the same time // the equation 'IC (group_input_channels) == C ' set up when group_count==1 (regular convolution) // and OC (group_output_channels) follows the same rule. @@ -81,6 +81,24 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, foutput_min, foutput_max, flags, code_cache, weights_cache, &p); + } else if (conv_type == OpComputeType::op_compute_type_fp16) { + const auto* B_data = Bias ? 
Bias->Data() : nullptr; + auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16 + : xnn_create_convolution2d_nhwc_f16; + status = create_func( + input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, + kernel_height, kernel_width, + subsampling_height, subsampling_width, + dilation_height, dilation_width, + group_count, + group_input_channels, + group_output_channels, + C, M, // input channel stride, output channel stride + Weight.Data(), B_data, // kernel, bias + foutput_min, foutput_max, + flags, + code_cache, weights_cache, + &p); } else if (conv_type == OpComputeType::op_compute_type_qs8) { const float output_scale = quant_param[2].first[0]; const int8_t output_zero_point = quant_param[2].second; @@ -236,6 +254,13 @@ OpComputeType GetConvCompType( return op_compute_type_qu8; } break; + case TensorTypeFp16: + if (input_datatype == TensorTypeFp16 && + (!bias_datatype || *bias_datatype == TensorTypeInt32) && + output_datatype == TensorTypeFp16) { + return op_compute_type_fp16; + } + break; default: break; } @@ -326,10 +351,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& // we only support float and u8 currently const auto* x_type = x_arg.TypeAsProto(); - if (x_type == nullptr || - (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) { break; } // require C, H, W to be known so we can construct the xnnpack kernel prior to Compute @@ -420,9 +442,11 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { weight_index = 3; conv_type_ = ParseQuantParamAndConType(info, quant_param_, input_dtype); + } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + conv_type_ = OpComputeType::op_compute_type_fp16; } else { auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto())); - ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8, but got ", stype); + ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8|FLOAT16, but got ", stype); } ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight), @@ -491,7 +515,6 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) output_shape_.push_back(M_); } - // have to delay creating the xnnpack kernel until after the weights are pre-packed. 
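The fp16 additions across Gemm, MatMul, Softmax, AveragePool and Conv above all follow the same shape: the kernel constructor maps the ONNX input element type to an OpComputeType, and each stage (create / reshape / setup) branches on that enum to pick the _f16 or _f32 XNNPACK entry point. The following is a minimal, self-contained sketch of that dispatch pattern only; the enum values and create_* functions are hypothetical stand-ins, not the ONNX Runtime or XNNPACK APIs.

#include <iostream>
#include <stdexcept>

// Hypothetical stand-ins for the ONNX element type and the EP's OpComputeType.
enum class ElemType { f32, f16, u8 };
enum class ComputeType { fp32, fp16, qu8, invalid };

// Constructor-time mapping, mirroring what the patched kernels do for FLOAT16.
ComputeType MapComputeType(ElemType e) {
  switch (e) {
    case ElemType::f32: return ComputeType::fp32;
    case ElemType::f16: return ComputeType::fp16;
    case ElemType::u8:  return ComputeType::qu8;
  }
  return ComputeType::invalid;
}

// Placeholders standing in for xnn_create_..._f32 / _f16 / _qu8.
int create_f32() { return 0; }
int create_f16() { return 0; }
int create_qu8() { return 0; }

// Call-site dispatch: the same if/else-on-enum shape used in PrePack/Compute.
int CreateKernel(ComputeType t) {
  if (t == ComputeType::fp32) return create_f32();
  if (t == ComputeType::fp16) return create_f16();
  if (t == ComputeType::qu8)  return create_qu8();
  throw std::runtime_error("unsupported compute type");
}

int main() {
  std::cout << CreateKernel(MapComputeType(ElemType::f16)) << "\n";  // fp16 path
  return 0;
}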
} diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index 01c8119fea79d..b6930a5fc92d1 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -7,6 +7,7 @@ #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" #include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "core/framework/tensorprotoutils.h" namespace onnxruntime { @@ -18,8 +19,10 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) || - (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W + const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || + conv_type_ == OpComputeType::op_compute_type_fp16); + if ((conv_type_is_float && input_idx == 1) || + (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); @@ -129,6 +132,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8; } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8; + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_deconvolution2d_nhwc_f16; } status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1, @@ -146,6 +151,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data(), Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data(), Y->MutableData()); + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_deconvolution2d_nhwc_f16(op0_.get(), X.Data(), Y->MutableData()); } if (status != xnn_status_success) { @@ -161,16 +168,18 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint( - "T", DataTypeImpl::GetTensorType()), - ConvTranspose); - ONNX_OPERATOR_VERSIONED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint( - "T", DataTypeImpl::GetTensorType()), + "T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), ConvTranspose); +ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ConvTranspose); + ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpackExecutionProvider, KernelDefBuilder() .TypeConstraint( diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc index 749e004094ba1..c828ae9400174 100644 --- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc @@ -3,6 +3,8 @@ #include "max_pool.h" +#include + #include "core/graph/graph.h" #include "core/providers/utils.h" 
#include "core/providers/xnnpack/xnnpack_init.h" @@ -168,8 +170,8 @@ MaxPool::MaxPool(const OpKernelInfo& info) auto input_dtype = X_arg.TypeAsProto()->tensor_type().elem_type(); xnn_status status = xnn_status_invalid_state; struct xnn_operator* p = nullptr; - float foutput_min = clip_min_max_ ? clip_min_max_->first : -INFINITY; - float foutput_max = clip_min_max_ ? clip_min_max_->second : INFINITY; + float foutput_min = clip_min_max_ ? clip_min_max_->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max_ ? clip_min_max_->second : std::numeric_limits::infinity(); if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { maxpool_type_ = OpComputeType::op_compute_type_fp32; status = xnn_create_max_pooling2d_nhwc_f32(input_padding_top, input_padding_right, @@ -200,14 +202,12 @@ MaxPool::MaxPool(const OpKernelInfo& info) output_min, output_max, flags, &p); } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { maxpool_type_ = OpComputeType::op_compute_type_fp16; - const float output_min = -65504.0; - const float output_max = 65504.0; status = xnn_create_max_pooling2d_nhwc_f16(input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, pooling_height, pooling_width, stride_height, stride_width, dilation_height, dilation_width, - output_min, output_max, flags, &p); + foutput_min, foutput_max, flags, &p); } else { auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto())); ORT_THROW("unsupported Conv in maxpool, we have FLOAT|UINT8|FLOAT16, but got ", stype); @@ -282,18 +282,21 @@ Status MaxPool::Compute(OpKernelContext* context) const { ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MaxPool); ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MaxPool); ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MaxPool); @@ -301,27 +304,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpa ONNX_OPERATOR_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, kXnnpackExecutionProvider, KernelDefBuilder() .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MaxPool); -#ifdef XNNPACK_FP16_SUPPORTED -ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, MLFloat16, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MaxPool); - -ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, MLFloat16, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MaxPool); - -ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, MLFloat16, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MaxPool); - -ONNX_OPERATOR_TYPED_KERNEL_EX(MaxPool, 
kMSInternalNHWCDomain, 12, MLFloat16, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MaxPool); -#endif - } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index db5648d5d6e54..45c292a36e3fe 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -29,9 +29,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, const auto& x_arg = inputs[0].node_arg; const auto* x_type = x_arg.TypeAsProto(); - if (x_type == nullptr || (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) { break; } @@ -181,6 +179,9 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: op_type_ = OpComputeType::op_compute_type_fp32; break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + op_type_ = OpComputeType::op_compute_type_fp16; + break; case ONNX_NAMESPACE::TensorProto_DataType_UINT8: op_type_ = OpComputeType::op_compute_type_qu8; break; @@ -189,7 +190,7 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf break; default: auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*input_defs[0]->TypeAsProto())); - ORT_THROW("unsupported op in Resize, we have FLOAT|UINT8|INT8, but get ", stype); + ORT_THROW("unsupported op in Resize, we have FLOAT|FLOAT16|UINT8|INT8, but get ", stype); } const auto* x_shape = input_defs[0]->Shape(); @@ -229,6 +230,8 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf auto out_w = output_dims_[2]; if (op_type_ == OpComputeType::op_compute_type_fp32) { xstatus = xnn_create_resize_bilinear2d_nhwc_f32(out_h, out_w, flags, &p); + } else if (op_type_ == OpComputeType::op_compute_type_fp16) { + xstatus = xnn_create_resize_bilinear2d_nhwc_f16(out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { xstatus = xnn_create_resize_bilinear2d_nhwc_u8(out_h, out_w, flags, &p); } else { @@ -261,7 +264,9 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, std::unique_ptr workspace(nullptr, deallocator); auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f32; - if (op_type_ == OpComputeType::op_compute_type_qu8) { + if (op_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f16; + } else if (op_type_ == OpComputeType::op_compute_type_qu8) { reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_u8; } else if (op_type_ == OpComputeType::op_compute_type_qs8) { reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8; @@ -279,6 +284,9 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, if (op_type_ == OpComputeType::op_compute_type_fp32) { status = xnn_setup_resize_bilinear2d_nhwc_f32(op0_.get(), workspace.get(), input->Data(), output->MutableData()); + } else if (op_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_resize_bilinear2d_nhwc_f16(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { status = 
xnn_setup_resize_bilinear2d_nhwc_u8(op0_.get(), workspace.get(), input->Data(), output->MutableData()); @@ -327,22 +335,26 @@ Status Resize::Compute(OpKernelContext* ctx) const { ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 11, 12, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 13, 17, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 18, 18, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index df7df0b4376ce..ee4e7be0f1f49 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -31,10 +31,6 @@ KernelCreateInfo BuildKernelCreateInfo() { BuildKernelCreateInfo< \ ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Op)> -#define KERNEL_CREATE_INFO_VERSIONED_TYPED(Start, End, Type, Op, Domain) \ - BuildKernelCreateInfo< \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Type, Op)> - #define KERNEL_CREATE_INFO(Start, Op, Domain) \ BuildKernelCreateInfo< \ ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Op)> @@ -43,19 +39,6 @@ KernelCreateInfo BuildKernelCreateInfo() { BuildKernelCreateInfo< \ ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)> -#ifdef XNNPACK_FP16_SUPPORTED -#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \ - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \ - startver, endver, MLFloat16, name) - -#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \ - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \ - MLFloat16, name) -#else -#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) -#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) -#endif - // Layout sensitive operators in NHWC domain class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); @@ -85,10 +68,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSIn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); -CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); -CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); -CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); -CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); // ONNX operators class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 8, Gemm); @@ -159,13 +138,6 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv, kMSInternalNHWCDomain), KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate), - -#ifdef XNNPACK_FP16_SUPPORTED - KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain), - KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain), - KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain), - KERNEL_CREATE_INFO_TYPED(12, MLFloat16, MaxPool, kMSInternalNHWCDomain), -#endif }; for (auto& function_table_entry : function_table) { @@ -286,6 +258,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> capabilities; std::shared_ptr registry = GetKernelRegistry(); @@ -296,7 +269,7 @@ std::vector> XnnpackExecutionProvider::GetCap // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit being checked for multiple times diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_init.h b/onnxruntime/core/providers/xnnpack/xnnpack_init.h index ed824939a40da..89e92d0b99b13 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_init.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_init.h @@ -46,15 +46,20 @@ namespace xnnpack { #define XNN_ALLOCATION_ALIGNMENT 16 #endif +#if defined(__arm__) || defined(_M_ARM) +#define XNN_ARCH_ARM 1 +#else +#define XNN_ARCH_ARM 0 +#endif + #if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC) #define XNN_ARCH_ARM64 1 #else #define XNN_ARCH_ARM64 0 #endif -// fp16 support can vary on a kernel by kernel basis. Keep it simple and limit to arm64 for now. -// e.g. XNNPACK maxpool has x64 and arm64 fp16 kernels. 
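With this change the fp16 MaxPool kernels are no longer registered through separate MLFloat16-typed macros; the regular kernel classes accept MLFloat16 directly, and XNNPACK_FP16_SUPPORTED (now gating ARM, ARM64 and non-mobile x86/x64, following xnn_is_f16_compatible_config) is only needed where fp16-specific code must be compiled out entirely. A small sketch of consuming such a compile-time gate is below; everything except the XNNPACK_FP16_SUPPORTED macro name is hypothetical.

#include <iostream>

// Fold the platform gate into a constexpr flag that registration code can test.
#if defined(XNNPACK_FP16_SUPPORTED)
constexpr bool kFp16KernelsAvailable = true;
#else
constexpr bool kFp16KernelsAvailable = false;
#endif

int main() {
  std::cout << "fp16 kernels compiled in: " << std::boolalpha
            << kFp16KernelsAvailable << "\n";
  return 0;
}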
-#if XNN_ARCH_ARM64 +// referenced from xnn_is_f16_compatible_config in XNNPACK/src/xnnpack/hardware-config.h +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 || ((XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE) #define XNNPACK_FP16_SUPPORTED #endif diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index e5e718fb8d1de..48213e3e3894a 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -117,19 +117,9 @@ Status Environment::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, co } // determine if arena should be used - const bool create_arena = [&]() -> bool { -#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) - // We use these allocators instead of the arena - return false; -#else - // Disable Arena allocator for 32-bit builds because it may run into infinite loop when integer overflow happens - if constexpr (sizeof(void*) == 4) { - return false; - } else { - return mem_info.alloc_type == OrtArenaAllocator; - } -#endif - }(); + const bool create_arena = DoesCpuAllocatorSupportArenaUsage() + ? (mem_info.alloc_type == OrtArenaAllocator) + : false; AllocatorPtr allocator_ptr; // create appropriate DeviceAllocatorRegistrationInfo and allocator based on create_arena diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 76d34aabab6cb..a60ee500a9898 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -249,7 +249,7 @@ Status GetMinimalBuildOptimizationHandling( std::atomic InferenceSession::global_session_id_{1}; std::map InferenceSession::active_sessions_; #ifdef _WIN32 -OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ +std::mutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ onnxruntime::WindowsTelemetry::EtwInternalCallback InferenceSession::callback_ML_ORT_provider_; #endif @@ -370,86 +370,12 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, // a monotonically increasing session id for use in telemetry session_id_ = global_session_id_.fetch_add(1); -#ifdef _WIN32 - std::lock_guard lock(active_sessions_mutex_); - active_sessions_[global_session_id_++] = this; - - // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider - callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( - [this](LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - (void)SourceId; - (void)Level; - (void)MatchAnyKeyword; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - - // Check if this callback is for capturing state - if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && - ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { - LogAllSessions(); - } - }); - WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); - - // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this](LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - 
PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - (void)SourceId; - (void)Level; - (void)MatchAnyKeyword; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - - if (logging_manager_ != nullptr) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && - IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - logging_manager_->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - } - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; - logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); - LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; - } - } - }); - - // Register callback for ETW capture state (rundown) - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); - -#endif - SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. 
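Several helpers touched in this file (TraceSessionOptions, GenerateTransformers, ParseTuningResultsFromModelMetadata) now receive the logger as an explicit parameter instead of reading member or global state, which removes the hidden dependency on session_logger_ during construction. A minimal sketch of that refactoring direction, with a hypothetical Logger type standing in for onnxruntime::logging::Logger:

#include <iostream>
#include <string>

// Hypothetical minimal logger; the real logging::Logger is richer.
struct Logger {
  void Info(const std::string& msg) const { std::cout << "[INFO] " << msg << "\n"; }
};

// Before: the helper read an implicit member. After: the caller passes the
// logger, so the helper has no hidden state and can even be made static.
void TraceOptions(const std::string& options_summary, const Logger& logger) {
  logger.Info("session options: " + options_summary);
}

int main() {
  Logger logger;
  TraceOptions("execution_mode=0 enable_profiling=false", logger);
  return 0;
}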
- TraceSessionOptions(session_options, false); + TraceSessionOptions(session_options, false, *session_logger_); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -575,14 +501,97 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } telemetry_ = {}; + +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[session_id_] = this; + + // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider + callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( + [](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + InferenceSession::LogAllSessions(); + } + }); + WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); + + // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (logging_manager_ != nullptr) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && + IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + logging_manager_->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + } + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; + logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); + LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; + } + } + }); + + // Register callback for ETW capture state (rundown) + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); + +#endif } -void 
InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger) { ORT_UNUSED_PARAMETER(captureState); // Otherwise Linux build error - LOGS(*session_logger_, INFO) << session_options; + LOGS(logger, INFO) << session_options; #ifdef _WIN32 + std::string optimized_model_filepath = ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath); + std::string profile_file_prefix = ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.profile_file_prefix); + TraceLoggingWrite(telemetry_provider_handle, "SessionOptions", TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), @@ -590,11 +599,11 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.execution_mode), "execution_mode"), TraceLoggingUInt8(static_cast(session_options.execution_order), "execution_order"), TraceLoggingBoolean(session_options.enable_profiling, "enable_profiling"), - TraceLoggingString(ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath).c_str(), "optimized_model_filepath"), + TraceLoggingString(optimized_model_filepath.c_str(), "optimized_model_filepath"), TraceLoggingBoolean(session_options.enable_mem_pattern, "enable_mem_pattern"), TraceLoggingBoolean(session_options.enable_mem_reuse, "enable_mem_reuse"), TraceLoggingBoolean(session_options.enable_cpu_mem_arena, "enable_cpu_mem_arena"), - TraceLoggingString(ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.profile_file_prefix).c_str(), "profile_file_prefix"), + TraceLoggingString(profile_file_prefix.c_str(), "profile_file_prefix"), TraceLoggingString(session_options.session_logid.c_str(), "session_logid"), TraceLoggingInt8(static_cast(session_options.session_log_severity_level), "session_log_severity_level"), TraceLoggingInt8(static_cast(session_options.session_log_verbosity_level), "session_log_verbosity_level"), @@ -725,11 +734,15 @@ InferenceSession::~InferenceSession() { // Unregister the session and ETW callbacks #ifdef _WIN32 - std::lock_guard lock(active_sessions_mutex_); - WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_); - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + std::lock_guard lock(active_sessions_mutex_); + if (callback_ML_ORT_provider_ != nullptr) { + WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_); + } + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif - active_sessions_.erase(global_session_id_); + active_sessions_.erase(session_id_); #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) @@ -745,7 +758,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for exec provider"); } - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if (is_inited_) { // adding an EP is pointless as the graph as already been partitioned so no nodes will be assigned to @@ -876,7 +889,7 @@ common::Status InferenceSession::RegisterGraphTransformer( return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer"); } - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if 
(is_inited_) { // adding a transformer now is pointless as the graph as already been transformed @@ -940,7 +953,7 @@ common::Status InferenceSession::LoadWithLoader(std::function l(session_mutex_); + std::lock_guard l(session_mutex_); if (is_model_loaded_) { // already loaded LOGS(*session_logger_, ERROR) << "This session already contains a loaded model."; return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model."); @@ -1396,7 +1409,7 @@ Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len } Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort_format_model_bytes) { - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if (is_model_loaded_) { // already loaded Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model."); @@ -1520,7 +1533,7 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort } bool InferenceSession::IsInitialized() const { - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); return is_inited_; } @@ -1631,7 +1644,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( level <= static_cast(session_options.graph_optimization_level); ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( - static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, + static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, logger, optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { @@ -1653,6 +1666,23 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) { } } } + +// This function is called when the session is being initialized. +// For now, this function only checks for invalid combination of DML EP with other EPs. +// TODO: extend this function to check for other invalid combinations of EPs. +common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const { + // DML EP is only allowed with CPU EP + bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr; + if (has_dml_ep) { + const auto& ep_list = execution_providers_.GetIds(); + for (const auto& ep : ep_list) { + if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue; + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP."); + } + } + return Status::OK(); +} + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) // VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong. @@ -1673,7 +1703,7 @@ common::Status InferenceSession::Initialize() { bool have_cpu_ep = false; { - std::lock_guard initial_guard(session_mutex_); + std::lock_guard initial_guard(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; @@ -1710,8 +1740,13 @@ common::Status InferenceSession::Initialize() { execution_providers_.SetCpuProviderWasImplicitlyAdded(true); } + // Check for the presence of an invalid combination of execution providers in the session + // For e.g. we don't support DML EP and other GPU EPs to be present in the same session + // This check is placed here because it serves as a common place for all language bindings. 
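The new HasInvalidCombinationOfExecutionProviders() check allows the DML EP to coexist only with the CPU EP. A standalone sketch of the same rule over a plain list of provider ids is shown here; the id strings are assumptions standing in for kDmlExecutionProvider / kCpuExecutionProvider.

#include <iostream>
#include <string>
#include <vector>

// Returns true and sets `error` if DML is present alongside any EP other than
// CPU, mirroring the logic added to InferenceSession::Initialize().
bool HasInvalidEpCombination(const std::vector<std::string>& ep_ids, std::string& error) {
  const std::string kDml = "DmlExecutionProvider";  // assumed id string
  const std::string kCpu = "CPUExecutionProvider";  // assumed id string
  bool has_dml = false;
  for (const auto& ep : ep_ids) {
    if (ep == kDml) { has_dml = true; break; }
  }
  if (!has_dml) return false;
  for (const auto& ep : ep_ids) {
    if (ep != kDml && ep != kCpu) {
      error = "DML EP can be used with only CPU EP.";
      return true;
    }
  }
  return false;
}

int main() {
  std::string err;
  bool invalid = HasInvalidEpCombination({"DmlExecutionProvider", "CUDAExecutionProvider"}, err);
  std::cout << std::boolalpha << invalid << " " << err << "\n";  // true, with the error message
  return 0;
}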
+ ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders()); + // re-acquire mutex - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); #if !defined(DISABLE_EXTERNAL_INITIALIZERS) && !defined(ORT_MINIMAL_BUILD) if (!session_options_.external_initializers.empty()) { @@ -1805,7 +1840,8 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_, session_options_.graph_optimization_level, minimal_build_optimization_handling, - record_runtime_optimization_produced_op_schema)); + record_runtime_optimization_produced_op_schema, + *session_logger_)); #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); @@ -2077,7 +2113,7 @@ common::Status InferenceSession::Initialize() { std::vector tuning_results; bool found_tuning_results = false; ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( - model_metadata_, tuning_results, found_tuning_results)); + model_metadata_, tuning_results, found_tuning_results, *session_logger_)); if (found_tuning_results) { ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results, /*error_on_invalid*/ false, /*auto_enable*/ true)); } @@ -2475,6 +2511,24 @@ struct ThreadPoolSpinningSwitch { }; } // namespace +Status InferenceSession::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + Status retval = Status::OK(); + + if (!is_inited_) { + LOGS(*session_logger_, ERROR) << "Session was not initialized"; + return Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); + } + + // TODO: only call SetEpDynamicOptions for all providers in-use + for (auto& xp : execution_providers_) { + auto status = xp->SetEpDynamicOptions(keys, values); + ORT_CHECK_AND_SET_RETVAL(status); + } + + return retval; +} + Status InferenceSession::Run(const RunOptions& run_options, gsl::span feed_names, gsl::span feeds, gsl::span output_names, std::vector* p_fetches, @@ -2566,7 +2620,7 @@ Status InferenceSession::Run(const RunOptions& run_options, std::unique_ptr owned_run_logger; const auto& run_logger = CreateLoggerForRun(run_options, owned_run_logger); - std::optional> sequential_run_lock; + std::optional> sequential_run_lock; if (is_concurrent_run_supported_ == false) { sequential_run_lock.emplace(session_mutex_); } @@ -2819,7 +2873,7 @@ common::Status InferenceSession::Run(const RunOptions& run_options, const NameML std::pair InferenceSession::GetModelMetadata() const { { - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); @@ -2831,7 +2885,7 @@ std::pair InferenceSession::GetModelMetada std::pair InferenceSession::GetModelInputs() const { { - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); @@ -2844,7 +2898,7 @@ std::pair InferenceSession::GetModelInputs( std::pair InferenceSession::GetOverridableInitializers() const { { - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, 
"Model was not loaded."), nullptr); @@ -2857,7 +2911,7 @@ std::pair InferenceSession::GetOverridableI std::pair InferenceSession::GetModelOutputs() const { { - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); @@ -2869,7 +2923,7 @@ std::pair InferenceSession::GetModelOutput common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { { - std::lock_guard l(session_mutex_); + std::lock_guard l(session_mutex_); if (!is_inited_) { LOGS(*session_logger_, ERROR) << "Session was not initialized"; return common::Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); @@ -3180,7 +3234,8 @@ common::Status InferenceSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const { + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); @@ -3192,7 +3247,7 @@ common::Status InferenceSession::AddPredefinedTransformers( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; if (use_full_build_optimizations) { - return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, + return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); @@ -3204,6 +3259,7 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, + logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); @@ -3253,18 +3309,25 @@ IOBinding* SessionIOBinding::Get() { void InferenceSession::LogAllSessions() { const Env& env = Env::Default(); - std::lock_guard lock(active_sessions_mutex_); + std::lock_guard lock(active_sessions_mutex_); for (const auto& session_pair : active_sessions_) { InferenceSession* session = session_pair.second; - onnxruntime::Graph& graph = model_->MainGraph(); - bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); - env.GetTelemetryProvider().LogSessionCreation( - session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), - graph.DomainToVersionMap(), graph.Name(), model_->MetaData(), - telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs, true); + if (!session) { + continue; + } + + auto model = session->model_; + if (nullptr != model) { + onnxruntime::Graph& graph = model->MainGraph(); + bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); + env.GetTelemetryProvider().LogSessionCreation( + session->session_id_, model->IrVersion(), model->ProducerName(), model->ProducerVersion(), model->Domain(), + 
graph.DomainToVersionMap(), graph.Name(), model->MetaData(), + session->telemetry_.event_name_, session->execution_providers_.GetIds(), model_has_fp16_inputs, true); + } - TraceSessionOptions(session->session_options_, true); + InferenceSession::TraceSessionOptions(session->session_options_, true, *session->session_logger_); } } #endif diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 8c22fac4dd0c5..e28ff75345785 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -29,7 +29,7 @@ #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" -#include "core/platform/ort_mutex.h" +#include #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" #endif @@ -129,7 +129,7 @@ class InferenceSession { using InputOutputDefMetaMap = InlinedHashMap; static std::map active_sessions_; #ifdef _WIN32 - static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ + static std::mutex active_sessions_mutex_; // Protects access to active_sessions_ static onnxruntime::WindowsTelemetry::EtwInternalCallback callback_ML_ORT_provider_; onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; #endif @@ -330,6 +330,9 @@ class InferenceSession { */ [[nodiscard]] common::Status Initialize(); + [[nodiscard]] common::Status SetEpDynamicOptions(gsl::span keys, + gsl::span values); + [[nodiscard]] common::Status Run(const RunOptions& run_options, gsl::span feed_names, gsl::span feeds, gsl::span output_names, std::vector* p_fetches, @@ -617,7 +620,7 @@ class InferenceSession { const Environment& session_env); void ConstructorCommon(const SessionOptions& session_options, const Environment& session_env); - + [[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const; [[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model); #if !defined(ORT_MINIMAL_BUILD) @@ -660,7 +663,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options, bool captureState); + static void TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; @@ -687,8 +690,9 @@ class InferenceSession { * If we encounter an invalid request, we return an error * back to the user. 
*/ - [[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list, - /*out*/ InlinedVector& arenas_to_shrink) const; + [[nodiscard]] common::Status ValidateAndParseShrinkArenaString( + const std::string& ort_device_list, + /*out*/ InlinedVector& arenas_to_shrink) const; /* * Performs the shrinkage of arenas requested to be shrunk by the user @@ -697,7 +701,7 @@ class InferenceSession { void ShrinkMemoryArenas(gsl::span arenas_to_shrink); #ifdef _WIN32 - void LogAllSessions(); + static void LogAllSessions(); #endif #if !defined(ORT_MINIMAL_BUILD) @@ -705,7 +709,8 @@ class InferenceSession { GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const; + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const; common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format); @@ -796,10 +801,10 @@ class InferenceSession { // Number of concurrently running executors std::atomic current_num_runs_ = 0; - mutable onnxruntime::OrtMutex session_mutex_; // to ensure only one thread can invoke Load/Initialize - bool is_model_loaded_ = false; // GUARDED_BY(session_mutex_) - bool is_inited_ = false; // GUARDED_BY(session_mutex_) - bool is_concurrent_run_supported_ = true; // Graph execution in Run is GUARDED_BY(session_mutex_) if false + mutable std::mutex session_mutex_; // to ensure only one thread can invoke Load/Initialize + bool is_model_loaded_ = false; // GUARDED_BY(session_mutex_) + bool is_inited_ = false; // GUARDED_BY(session_mutex_) + bool is_concurrent_run_supported_ = true; // Graph execution in Run is GUARDED_BY(session_mutex_) if false #ifdef ENABLE_LANGUAGE_INTEROP_OPS InterOpDomains interop_domains_; diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index 3436eebda3819..8b9de0c604441 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -236,7 +236,8 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, std::vector& results, - bool& key_found) { + bool& key_found, + const logging::Logger& logger) { results.clear(); key_found = false; auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); @@ -245,7 +246,7 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met } key_found = true; - LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model"; + LOGS(logger, INFO) << "Found tuning results in the model file to be used while loading the model"; Status status; ORT_TRY { diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index a0bcdb9013bf0..f297d928f8a0d 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -19,7 +19,9 @@ using json = nlohmann::json; #endif namespace onnxruntime { - +namespace logging { +class Logger; +} namespace inference_session_utils { // need this value to be accessible in all builds in order to report error for attempted usage in a minimal build @@ -60,7 +62,8 @@ 
class JsonConfigParser { Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, /*out*/ std::vector& results, - /*out*/ bool& key_found); + /*out*/ bool& key_found, + const logging::Logger& logger); #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 64546e634694f..ca6950af0227a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -36,7 +36,7 @@ #include "core/framework/data_types.h" #include "abi_session_options_impl.h" #include "core/framework/TensorSeq.h" -#include "core/platform/ort_mutex.h" +#include #include "core/common/string_helper.h" #include "core/session/lora_adapters.h" @@ -843,6 +843,28 @@ void CheckAndAdjustInputSpansForLora(const OrtRunOptions& run_options, } // namespace +ORT_API_STATUS_IMPL(OrtApis::SetEpDynamicOptions, _Inout_ OrtSession* sess, + _In_reads_(kv_len) const char* const* keys, + _In_reads_(kv_len) const char* const* values, + _In_ size_t kv_len) { + API_IMPL_BEGIN + auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess); + + auto keys_span = gsl::make_span(keys, kv_len); + auto values_span = gsl::make_span(values, kv_len); + + Status status; + + if (kv_len == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "no imputs were passed"); + } else { + status = session->SetEpDynamicOptions(keys_span, + values_span); + } + return ToOrtStatus(status); + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, _In_reads_(input_len) const char* const* input_names, _In_reads_(input_len) const OrtValue* const* input, size_t input_len, @@ -2447,7 +2469,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_20 = { +static constexpr OrtApi ort_api_1_to_21 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. // Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -2781,10 +2803,15 @@ static constexpr OrtApi ort_api_1_to_20 = { &OrtApis::KernelInfoGetAllocator, &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + // End of Version 19 - DO NOT MODIFY ABOVE (see above text for more information) + &OrtApis::CreateLoraAdapter, &OrtApis::CreateLoraAdapterFromArray, &OrtApis::ReleaseLoraAdapter, &OrtApis::RunOptionsAddActiveLoraAdapter, + + &OrtApis::SetEpDynamicOptions, + // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
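Editor's note (not part of the patch): SetEpDynamicOptions ships as the final entry of API version 20, so callers reach it through the regular OrtApi table. A minimal usage sketch follows; the option key and value are placeholders, since the accepted pairs are defined by each execution provider's SetEpDynamicOptions override, and passing zero pairs returns ORT_INVALID_ARGUMENT as implemented above.

#include <cstdio>
#include "onnxruntime_c_api.h"

// Sketch: forwarding a dynamic option to every execution provider in a live session.
void SetDynamicOption(OrtSession* session) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  const char* keys[] = {"ep.dynamic.workload_type"};  // hypothetical key
  const char* values[] = {"Efficient"};               // hypothetical value
  OrtStatus* status = api->SetEpDynamicOptions(session, keys, values, 1);
  if (status != nullptr) {
    std::fprintf(stderr, "SetEpDynamicOptions failed: %s\n", api->GetErrorMessage(status));
    api->ReleaseStatus(status);
  }
}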
@@ -2816,18 +2843,20 @@ static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_OpenVINO_V2) / sizeof(void*) == 275, "Size of version 17 API cannot change"); static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeof(void*) == 279, "Size of version 18 API cannot change"); +// no additions in version 19 +static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.20.0", +static_assert(std::string_view(ORT_VERSION) == "1.21.0", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it -// 2. If there were any APIs added to ort_api_1_to_20 above: +// 2. If there were any APIs added to ort_api_1_to_21 above: // a. Add the 'End of version #' markers (pattern above should be obvious) // b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version >= 1 && version <= ORT_API_VERSION) - return &ort_api_1_to_20; + return &ort_api_1_to_21; fprintf(stderr, "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build." diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9054246873232..52d3c98d526dc 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -531,4 +531,6 @@ ORT_API_STATUS_IMPL(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t n ORT_API(void, ReleaseLoraAdapter, _Frees_ptr_opt_ OrtLoraAdapter*); ORT_API_STATUS_IMPL(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); +ORT_API_STATUS_IMPL(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, + _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 3c178fd1e91d3..335ebbf203e7c 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -17,15 +17,27 @@ using namespace onnxruntime; using namespace onnxruntime::logging; +#ifdef USE_WEBGPU +namespace onnxruntime { +namespace webgpu { +void CleanupWebGpuContexts(); +} // namespace webgpu +} // namespace onnxruntime +#endif + std::unique_ptr OrtEnv::p_instance_; int OrtEnv::ref_count_ = 0; -onnxruntime::OrtMutex OrtEnv::m_; +std::mutex OrtEnv::m_; OrtEnv::OrtEnv(std::unique_ptr value1) : value_(std::move(value1)) { } OrtEnv::~OrtEnv() { +#ifdef USE_WEBGPU + webgpu::CleanupWebGpuContexts(); +#endif + // We don't support any shared providers in the minimal build yet #if !defined(ORT_MINIMAL_BUILD) UnloadSharedProviders(); @@ -35,7 +47,7 @@ OrtEnv::~OrtEnv() { OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status, const OrtThreadingOptions* tp_options) { - std::lock_guard lock(m_); + std::lock_guard lock(m_); if (!p_instance_) { std::unique_ptr lmgr; std::string name = 
lm_info.logid; @@ -76,7 +88,7 @@ void OrtEnv::Release(OrtEnv* env_ptr) { if (!env_ptr) { return; } - std::lock_guard lock(m_); + std::lock_guard lock(m_); ORT_ENFORCE(env_ptr == p_instance_.get()); // sanity check --ref_count_; if (ref_count_ == 0) { diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 444134d0612e9..64e0020f2930d 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -5,7 +5,7 @@ #include #include #include "core/session/onnxruntime_c_api.h" -#include "core/platform/ort_mutex.h" +#include #include "core/common/status.h" #include "core/common/logging/logging.h" #include "core/framework/allocator.h" @@ -67,7 +67,7 @@ struct OrtEnv { private: static std::unique_ptr p_instance_; - static onnxruntime::OrtMutex m_; + static std::mutex m_; static int ref_count_; std::unique_ptr value_; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 85079ef78c8d3..1444c1976d447 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -41,6 +41,7 @@ #include "core/session/onnxruntime_c_api.h" #include "core/common/string_helper.h" +#include #ifdef ENABLE_TRAINING #ifdef ENABLE_TRAINING_TORCH_INTEROP @@ -278,8 +279,9 @@ struct ProviderHostImpl : ProviderHost { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) override { - return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) override { + return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } @@ -495,6 +497,7 @@ struct ProviderHostImpl : ProviderHost { void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_name(value); } void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, ONNX_NAMESPACE::AttributeProto_AttributeType value) override { return p->set_type(value); } ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) override { return p->add_tensors(); } + std::string* AttributeProto__release_s(ONNX_NAMESPACE::AttributeProto* p) override { return p->release_s(); } // GraphProto (wrapped) std::unique_ptr GraphProto__construct() override { return std::make_unique(); } @@ -706,6 +709,12 @@ struct ProviderHostImpl : ProviderHost { return p->GetConfigEntry(config_key); } + // ConfigOptions (wrapped) + std::string ConfigOptions__GetConfigOrDefault(const ConfigOptions* p, const std::string& config_key, + const std::string& default_value) override { + return p->GetConfigOrDefault(config_key, default_value); + } + // OrtRunOptions (wrapped) const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) override { return p->config_options; } @@ -1050,8 +1059,8 @@ struct ProviderHostImpl : ProviderHost { } std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) override { - return QDQ::GetAllNodeUnits(*graph_viewer); + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) override { + return 
QDQ::GetAllNodeUnits(*graph_viewer, logger); } // Model (wrapped) @@ -1149,8 +1158,8 @@ struct ProviderHostImpl : ProviderHost { // GraphViewer (wrapped) void GraphViewer__operator_delete(GraphViewer* p) override { delete p; } - std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger) override { - return std::make_unique(graph_viewer->Name(), true, ModelMetaData(), PathString(), + std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) override { + return std::make_unique(graph_viewer->Name(), true, metadata, PathString(), #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(), #else @@ -1205,6 +1214,7 @@ struct ProviderHostImpl : ProviderHost { GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order)); } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } + IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const override { return p->GetSchemaRegistry(); } // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } @@ -1783,12 +1793,6 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - if (legacy_ov_options->enable_npu_fast_compile) { - ov_options_converted_map["enable_npu_fast_compile"] = "false"; - } else { - ov_options_converted_map["enable_npu_fast_compile"] = "true"; - } - if (legacy_ov_options->num_of_threads != '\0') ov_options_converted_map["num_of_threads"] = std::to_string(legacy_ov_options->num_of_threads); @@ -1809,51 +1813,24 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O ov_options_converted_map["disable_dynamic_shapes"] = "true"; } + if (legacy_ov_options->enable_npu_fast_compile) { + LOGS_DEFAULT(WARNING) << "enable_npu_fast_compile option is deprecated. 
Skipping this option"; + } // Add new provider option below ov_options_converted_map["num_streams"] = "1"; - ov_options_converted_map["export_ep_ctx_blob"] = "false"; + ov_options_converted_map["load_config"] = ""; ov_options_converted_map["model_priority"] = "DEFAULT"; ov_options_converted_map["enable_qdq_optimizer"] = "false"; return ov_options_converted_map; } -std::shared_ptr OpenVINOProviderFactoryCreator::Create(const OrtOpenVINOProviderOptions* provider_options) { - ProviderOptions ov_options_converted_map = onnxruntime::OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(provider_options); - return s_library_openvino.Get().CreateExecutionProviderFactory(&ov_options_converted_map); -} - -void ORTSessionOptionsToOrtOpenVINOProviderOptions(ProviderOptions& ov_options, - const SessionOptions* session_options) { - bool disable_cpu_fallback = session_options->config_options.GetConfigOrDefault( - kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; - if (disable_cpu_fallback) - ov_options["disable_cpu_fallback"] = "true"; - - // values from session options will override the providerOptions Value - bool so_epctx_enable = session_options->config_options.GetConfigOrDefault( - kOrtSessionOptionEpContextEnable, "0") == "1"; - if (so_epctx_enable) - ov_options["so_export_ep_ctx_blob"] = "true"; - - std::string so_cache_path = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "").c_str(); - ov_options["so_epctx_path"] = so_cache_path; - - // Default embedMode is 1. Saving the compiled model contents as a Epctx node attribute - bool so_epctx_embed_mode = session_options->config_options.GetConfigOrDefault( - kOrtSessionOptionEpContextEmbedMode, "1") == "0"; - if (so_epctx_embed_mode) { - // defaults to true - ov_options["so_epctx_embed_mode"] = "false"; - } -} - -std::shared_ptr OpenVINOProviderFactoryCreator::Create(ProviderOptions* provider_options_map, - const SessionOptions* session_options) { +std::shared_ptr OpenVINOProviderFactoryCreator::Create( + const ProviderOptions* provider_options_map, const SessionOptions* session_options) { // Append session options applicable for EP to EP Provider options. 
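Editor's note (not part of the patch): with this change the legacy OrtOpenVINOProviderOptions struct is reduced to a conversion into the string-map options (its enable_npu_fast_compile field now only triggers the deprecation warning above), and the factory consumes that map together with the session's ConfigOptions. Callers moving off the legacy struct can use the "V2" C API directly; the keys and values in this sketch are illustrative assumptions and should be checked against the OpenVINO EP documentation.

#include "onnxruntime_c_api.h"

// Sketch: appending the OpenVINO EP via the string-map ("V2") provider options.
OrtStatus* AppendOpenVINO(OrtSessionOptions* session_options) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  const char* keys[] = {"device_type", "num_of_threads"};  // assumed valid keys
  const char* values[] = {"CPU", "8"};
  return api->SessionOptionsAppendExecutionProvider_OpenVINO_V2(
      session_options, keys, values, 2);
}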
- if (session_options) { - onnxruntime::ORTSessionOptionsToOrtOpenVINOProviderOptions(*provider_options_map, session_options); - } - return s_library_openvino.Get().CreateExecutionProviderFactory(provider_options_map); + std::pair config_buffer = {provider_options_map, + session_options->config_options}; + const void* obj = reinterpret_cast(&config_buffer); + return s_library_openvino.Get().CreateExecutionProviderFactory(obj); } std::shared_ptr DnnlProviderFactoryCreator::Create(const OrtDnnlProviderOptions* dnnl_options) { @@ -2106,9 +2083,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_MIGraphX, _In API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) { +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, + _In_ const OrtOpenVINOProviderOptions* provider_options) { API_IMPL_BEGIN - auto factory = onnxruntime::OpenVINOProviderFactoryCreator::Create(provider_options); + const onnxruntime::ProviderOptions ov_options_converted_map = onnxruntime::OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(provider_options); + auto factory = onnxruntime::OpenVINOProviderFactoryCreator::Create(&ov_options_converted_map, &(options->value)); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_OpenVINO: Failed to load shared library"); } @@ -2264,7 +2243,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, new_tensorrt_options.trt_ep_context_file_path = (context_cache_path.size() == 0) ? nullptr : context_cache_path.c_str(); LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path; - embed_mode = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + embed_mode = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0"); if ("1" == embed_mode) { new_tensorrt_options.trt_ep_context_embed_mode = 1; } else if ("0" == embed_mode) { diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 8c512c561ea8c..7fb518cdc05ca 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -155,11 +155,21 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, status = create_not_supported_status(); #endif } else if (strcmp(provider_name, "VitisAI") == 0) { +#ifdef USE_VITISAI status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys); +#else + status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "CoreML") == 0) { +#if defined(USE_COREML) + options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); +#endif } else { ORT_UNUSED_PARAMETER(options); status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'"); + "Unknown provider name. 
Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' ,'CoreML', and 'AZURE'"); } return status; @@ -205,15 +215,6 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, } #endif -#ifndef USE_TVM -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tvm, - _In_ OrtSessionOptions* options, _In_ const char* settings) { - ORT_UNUSED_PARAMETER(options); - ORT_UNUSED_PARAMETER(settings); - return CreateNotEnabledStatus("Tvm"); -} -#endif - #ifdef __cplusplus } #endif diff --git a/onnxruntime/core/session/standalone_op_invoker.cc b/onnxruntime/core/session/standalone_op_invoker.cc index 9cbf01946e92b..2706448d831cc 100644 --- a/onnxruntime/core/session/standalone_op_invoker.cc +++ b/onnxruntime/core/session/standalone_op_invoker.cc @@ -314,7 +314,8 @@ class StandAloneKernelContext : public OpKernelContext { AllocatorPtr allocator_; }; // StandAloneKernelContext -onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, OrtOpAttr** op_attr) { +onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, + OrtOpAttr** op_attr) { auto attr = std::make_unique(); onnxruntime::Status status = onnxruntime::Status::OK(); attr->set_name(std::string{name}); @@ -410,7 +411,9 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info, node_ptr->SetSinceVersion(version); - auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, &kernel_create_info); + auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, + logging::LoggingManager::DefaultLogger(), // no other logger available + &kernel_create_info); ORT_RETURN_IF_ERROR(status); auto& kernel_def = kernel_create_info->kernel_def; diff --git a/onnxruntime/lora/adapter_format_utils.cc b/onnxruntime/lora/adapter_format_utils.cc index 9a6f8f3b7b1c8..7986082da06f7 100644 --- a/onnxruntime/lora/adapter_format_utils.cc +++ b/onnxruntime/lora/adapter_format_utils.cc @@ -6,6 +6,8 @@ #include "core/framework/allocator.h" #include "core/common/common.h" +#include "core/framework/endian.h" +#include "core/framework/endian_utils.h" #include "core/common/span_utils.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" @@ -75,17 +77,75 @@ const Adapter* ValidateAndGetAdapterFromBytes(gsl::span bytes) { return adapter; } +template +struct WriteDataForLittleEndian { + Status operator()(gsl::span src, gsl::span dest) const { + auto src_span = ReinterpretAsSpan(src); + return onnxruntime::utils::WriteLittleEndian(src_span, dest); + } +}; + void SaveLoraParameter(flatbuffers::FlatBufferBuilder& flat_builder, std::string_view name, TensorDataType data_type, gsl::span shape, gsl::span data, flatbuffers::Offset& fbs_tensor) { auto name_str = (name.empty()) ? 
0 : flat_builder.CreateString(name.data(), name.size()); auto shape_vec = flat_builder.CreateVector(shape.data(), shape.size()); - auto data_vec = flat_builder.CreateVector(data.data(), data.size()); + flatbuffers::Offset> data_vec; + if constexpr (endian::native == endian::big) { + const auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast(data_type))->GetElementType(); + if (elem_type->Size() > 1) { + InlinedVector be_data(data.size()); + auto be_data_span = ReinterpretAsSpan(AsSpan(be_data)); + + onnxruntime::utils::MLTypeCallDispatcher + disp(static_cast(data_type)); + + ORT_THROW_IF_ERROR((disp.InvokeRet(data, be_data_span))); + data_vec = flat_builder.CreateVector(be_data.data(), be_data.size()); + } else { + data_vec = flat_builder.CreateVector(data.data(), data.size()); + } + } else { + data_vec = flat_builder.CreateVector(data.data(), data.size()); + } fbs_tensor = CreateParameter(flat_builder, name_str, shape_vec, data_type, data_vec); } +template +struct ReadDataForBigEndian { + Status operator()(gsl::span src, Tensor& dst) const { + auto dst_span = dst.MutableDataAsSpan(); + return onnxruntime::utils::ReadLittleEndian(src, dst_span); + } +}; + +// If BE, we a allocate memory within the tensor and copy there swapping bytes +[[maybe_unused]] static Status CreateOrtValueForBePlatforms(const Parameter& param, const MLDataType elem_type, + gsl::span shape, OrtValue& result) { + static const AllocatorPtr cpu_allocator = std::make_shared(); + + auto src_span = ReinterpretAsSpan( + gsl::make_span(param.raw_data()->data(), param.raw_data()->size())); + + const auto data_type = param.data_type(); + + Tensor tensor(elem_type, shape, cpu_allocator); + onnxruntime::utils::MLTypeCallDispatcher + disp(static_cast(data_type)); + + ORT_RETURN_IF_ERROR((disp.InvokeRet(src_span, tensor))); + Tensor::InitOrtValue(std::move(tensor), result); + return Status::OK(); +} + std::pair CreateOrtValueOverLoraParameter(const Parameter& param) { OrtValue result; @@ -93,17 +153,32 @@ std::pair CreateOrtValueOverLoraParameter(const Parameter LoadStringFromLoraFormat(name, param.name()); const auto data_type = param.data_type(); - gsl::span shape_span(param.dims()->data(), param.dims()->size()); - + // Copying shape takes care of endianess using flatbuffers accessors + TensorShapeVector shape(param.dims()->begin(), param.dims()->end()); + const auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast(data_type))->GetElementType(); static const OrtMemoryInfo cpu_meminfo(CPU, OrtAllocatorType::OrtDeviceAllocator); - auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast(data_type))->GetElementType(); - // const_cast is necessery due to Tensor class API - Tensor::InitOrtValue(elem_type, - TensorShape(shape_span), - const_cast(param.raw_data()->data()), - cpu_meminfo, - result); + if constexpr (endian::native == endian::big) { + if (elem_type->Size() > 1) { + ORT_THROW_IF_ERROR(CreateOrtValueForBePlatforms(param, elem_type, shape, result)); + } else { + // Single byte elements allow us to create OrtValue directly on top + // of raw data + // const_cast is necessary due to Tensor class API + Tensor::InitOrtValue(elem_type, + TensorShape(shape), + const_cast(param.raw_data()->data()), + cpu_meminfo, + result); + } + } else { + // const_cast is necessary due to Tensor class API + Tensor::InitOrtValue(elem_type, + TensorShape(shape), + const_cast(param.raw_data()->data()), + cpu_meminfo, + result); + } return std::make_pair(std::move(name), std::move(result)); } diff --git 
a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py index 97b7358f2a223..67423fe9b5a33 100644 --- a/onnxruntime/python/backend/backend.py +++ b/onnxruntime/python/backend/backend.py @@ -87,7 +87,7 @@ def supports_device(cls, device): """ if device == "CUDA": device = "GPU" - return device in get_device() + return "-" + device in get_device() or device + "-" in get_device() or device == get_device() @classmethod def prepare(cls, model, device=None, **kwargs): diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index f4f10dc4b4b97..d05fba192820a 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -26,6 +26,8 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice: return C.OrtDevice.cpu() elif device_type == "dml": return C.OrtDevice.dml() + elif device_type == "webgpu": + return C.OrtDevice.webgpu() elif device_type == "ort": return C.get_ort_device(device_index).device_type() else: diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 084ee6bc50698..5742b4db42512 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -280,7 +280,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { uint32_t readback_heap_size = gsl::narrow_cast(sizeof(readback_heap)); ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(dml_readback_heap_guid, &readback_heap_size, &readback_heap)); - // ReadbackFromGpu already syncs with the CPU and waits for the copy to be completed, so we don't need to sync after + // ReadbackFromGpu already syncs with the CPU and waits for the copy to be completed, so we dont need to sync after // this call readback_heap->ReadbackFromGpu( gsl::make_span(static_cast(dst), num_bytes), @@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { const std::unordered_map* GetDmlToHostMemCpyFunction() { static std::unordered_map map{ - {OrtDevice::GPU, DmlToCpuMemCpy}}; + {OrtDevice::DML, DmlToCpuMemCpy}}; return ↦ } @@ -428,7 +428,7 @@ MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type) { // Special, not a C type expands to enum value of 16 {NPY_FLOAT16, DataTypeImpl::GetType()}, {NPY_DOUBLE, DataTypeImpl::GetType()}, - // We don't want to use size specific types such + // We dont want to use size specific types such // as NPY_INT32 bc they are not enums but hash defines // which may map into other enums and may conflict with other entries here // also NPY docs define these sizes as platform specific, thus we @@ -581,6 +581,7 @@ static void CopyDataToTensor(PyArrayObject* darray, int npy_type, Tensor& tensor for (int i = 0; i < total_items; ++i, src += item_size) { // Python unicode strings are assumed to be USC-4. Strings are stored as UTF-8. 
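Editor's note (not part of the patch): the itemGuard added on the next line exists because PyArray_GETITEM hands back a new PyObject reference for every string element; previously only the PyObject_Str result was guarded, so the temporary element object leaked one reference per string copied. The patch's UniqueDecRefPtr/DecRefFn pair plays the role of the small RAII helper sketched below (the sketch assumes CPython's Py_XDECREF semantics; it is not the ORT definition).

#include <Python.h>  // CPython C API; requires the Python development headers
#include <memory>

// Sketch: an RAII guard that drops one Python reference when it leaves scope.
struct PyDecRef {
  void operator()(PyObject* p) const { Py_XDECREF(p); }
};
using ScopedPyObject = std::unique_ptr<PyObject, PyDecRef>;

// Usage shape, mirroring the patched loop (darray/src come from the surrounding code):
//   ScopedPyObject item(PyArray_GETITEM(darray, src));  // new reference, released on scope exit
//   ScopedPyObject str(PyObject_Str(item.get()));       // PyObject_Str also returns a new reference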
PyObject* item = PyArray_GETITEM(darray, src); + UniqueDecRefPtr itemGuard(item, DecRefFn()); PyObject* pStr = PyObject_Str(item); UniqueDecRefPtr strGuard(pStr, DecRefFn()); dst[i] = py::reinterpret_borrow(pStr); diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 18785cd607eaa..6a57fc5f900ae 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -96,16 +96,22 @@ void addOrtValueMethods(pybind11::module& m) { // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy); -#elif USE_DML - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML - CreateGenericMLValue( - nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); #else - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); +#endif + } else if (device.Type() == OrtDevice::DML) { +#if USE_DML + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML + CreateGenericMLValue( + nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); +#else + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); #endif } else if (device.Type() == OrtDevice::NPU) { #ifdef USE_CANN @@ -116,9 +122,9 @@ void addOrtValueMethods(pybind11::module& m) { CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCannMemCpy); #else - throw std::runtime_error( - "Can't allocate memory on the CANN device using this package of OnnxRuntime. " - "Please use the CANN package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the CANN device using this package of OnnxRuntime. 
" + "Please use the CANN package of OnnxRuntime to use this feature."); #endif } else { throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); @@ -160,19 +166,24 @@ void addOrtValueMethods(pybind11::module& m) { } onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToRocmMemCpy); -#elif USE_DML + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToRocmMemCpy); +#else + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); +#endif + } else if (device.Type() == OrtDevice::DML) { +#if USE_DML onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToDmlMemCpy); + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToDmlMemCpy); #else - throw std::runtime_error( - "Unsupported GPU device: Cannot find the supported GPU device."); + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); #endif } else { throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device"); diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 1319e8f6fe959..958da26f4faf0 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -69,11 +69,14 @@ void addGlobalSchemaFunctions(pybind11::module& m) { #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), #endif +#ifdef USE_VSINPU + onnxruntime::VSINPUProviderFactoryCreator::Create(), +#endif #ifdef USE_RKNPU onnxruntime::RknpuProviderFactoryCreator::Create(), #endif #ifdef USE_COREML - onnxruntime::CoreMLProviderFactoryCreator::Create(0), + onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}), #endif #ifdef USE_XNNPACK onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 3062738eefcf2..9d544c0cee9ed 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) { case OrtDevice::CPU: return CPU; case OrtDevice::GPU: -#ifdef USE_DML - return DML; -#else return CUDA; -#endif + case OrtDevice::DML: + return DML; case OrtDevice::FPGA: return "FPGA"; case OrtDevice::NPU: @@ -1062,12 +1060,6 @@ std::unique_ptr CreateExecutionProviderInstance( } else if (option.first == "precision") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "enable_npu_fast_compile") { - if (!(option.second == "True" || option.second == "true" || - option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_npu_fast_compile: ", option.second); - } - OV_provider_options_map[option.first] = option.second; } else if (option.first == "enable_opencl_throttling") { if (!(option.second == "True" || option.second == "true" || option.second == "False" || option.second == "false")) { @@ -1103,13 +1095,13 @@ std::unique_ptr CreateExecutionProviderInstance( } else if (option.first == "num_streams") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "cache_dir") { + } else if (option.first == "load_config") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "context") { 
+ } else if (option.first == "cache_dir") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "export_ep_ctx_blob") { + } else if (option.first == "context") { OV_provider_options_map[option.first] = option.second; continue; } else if (option.first == "enable_qdq_optimizer") { @@ -1133,16 +1125,6 @@ std::unique_ptr CreateExecutionProviderInstance( LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met."; } } -#endif - } else if (type == kTvmExecutionProvider) { -#if USE_TVM - onnxruntime::tvm::TvmEPOptions info{}; - const auto it = provider_options_map.find(type); - if (it != provider_options_map.end()) { - info = onnxruntime::tvm::TvmEPOptionsHelper::FromProviderOptions(it->second); - } - - return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider(); #endif } else if (type == kVitisAIExecutionProvider) { #ifdef USE_VITISAI @@ -1198,6 +1180,10 @@ std::unique_ptr CreateExecutionProviderInstance( const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry( kOrtSessionOptionsConfigNnapiEpPartitioningStopOps); return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider(); +#endif + } else if (type == kVSINPUExecutionProvider) { +#ifdef USE_VSINPU + return onnxruntime::VSINPUProviderFactoryCreator::Create()->CreateProvider(); #endif } else if (type == kRknpuExecutionProvider) { #ifdef USE_RKNPU @@ -1219,6 +1205,8 @@ std::unique_ptr CreateExecutionProviderInstance( if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) { coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY; + } else if (flags_str.find("COREML_FLAG_USE_CPU_AND_GPU") != std::string::npos) { + coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_AND_GPU; } if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) { @@ -1228,6 +1216,9 @@ std::unique_ptr CreateExecutionProviderInstance( if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) { coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM; } + } else { + // read from provider_options + return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); } } @@ -1583,7 +1574,8 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def_static("cann", []() { return OrtDevice::NPU; }) .def_static("fpga", []() { return OrtDevice::FPGA; }) .def_static("npu", []() { return OrtDevice::NPU; }) - .def_static("dml", []() { return OrtDevice::GPU; }) + .def_static("dml", []() { return OrtDevice::DML; }) + .def_static("webgpu", []() { return OrtDevice::GPU; }) .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; }); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 225931533615d..995341b0f8dc0 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -24,7 +24,7 @@ struct OrtStatus { char msg[1]; // a null-terminated string }; -#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_OPENVINO BACKEND_TVM BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML BACKEND_CANN +#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_OPENVINO BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL 
BACKEND_ARMNN BACKEND_DML BACKEND_CANN BACKEND_WEBGPU #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/providers.h" #include "core/providers/provider_factory_creators.h" @@ -75,12 +75,6 @@ struct OrtStatus { #define BACKEND_OPENVINO "" #endif -#ifdef USE_TVM -#define BACKEND_TVM "-TVM" -#else -#define BACKEND_TVM "" -#endif - #if USE_OPENBLAS #define BACKEND_OPENBLAS "-OPENBLAS" #else @@ -111,6 +105,12 @@ struct OrtStatus { #define BACKEND_CANN "" #endif +#if USE_WEBGPU +#define BACKEND_WEBGPU "-WEBGPU" +#else +#define BACKEND_WEBGPU "" +#endif + #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/cuda_execution_provider_info.h" @@ -135,9 +135,6 @@ extern std::string openvino_device_type; } } // namespace onnxruntime #endif -#ifdef USE_TVM -#include "core/providers/tvm/tvm_ep_options.h" -#endif #ifdef USE_ACL #include "core/providers/acl/acl_provider_factory.h" #endif @@ -438,15 +435,12 @@ std::shared_ptr CreateExecutionProviderFactory_MIGrap std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); std::shared_ptr CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Dnnl(const OrtDnnlProviderOptions* params); -#ifdef USE_TVM -std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); -std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); -#endif std::shared_ptr CreateExecutionProviderFactory_ACL(bool enable_fast_math); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); +std::shared_ptr CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_CoreML(uint32_t flags); constexpr const char* kDefaultExecutionProviderEntry = "GetProvider"; diff --git a/onnxruntime/python/providers/tvm/__init__.py b/onnxruntime/python/providers/tvm/__init__.py deleted file mode 100644 index 4bcbc0bfef586..0000000000000 --- a/onnxruntime/python/providers/tvm/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -JIT interface implementing packed functions that -import and compile frontend models -""" -from .ort import ANSOR_TYPE, AUTO_TVM_TYPE, onnx_compile # noqa: F401 diff --git a/onnxruntime/python/providers/tvm/extend_python_file.py b/onnxruntime/python/providers/tvm/extend_python_file.py deleted file mode 100644 index 65902619f8150..0000000000000 --- a/onnxruntime/python/providers/tvm/extend_python_file.py +++ /dev/null @@ -1,54 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import argparse -import textwrap - - -def rewrite_target_file(target): - with open(target, "a") as f: - f.write( - textwrap.dedent( - """ - import warnings - - try: - # This import is necessary in order to delegate the loading of libtvm.so to TVM. - import tvm - except ImportError as e: - warnings.warn( - f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}" - ) - try: - # Working between the C++ and Python parts in TVM EP is done using the PackedFunc and - # Registry classes. In order to use a Python function in C++ code, it must be registered in - # the global table of functions. Registration is carried out through the JIT interface, - # so it is necessary to call special functions for registration. - # To do this, we need to make the following import. - import onnxruntime.providers.tvm - except ImportError as e: - warnings.warn( - f"WARNING: Failed to register python functions to work with TVM EP. More details: {e}" - ) - """ - ) - ) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--target_file", - type=str, - required=True, - help="Path to the file to be expanded.", - ) - args = parser.parse_args() - rewrite_target_file(args.target_file) - - -if __name__ == "__main__": - main() diff --git a/onnxruntime/python/providers/tvm/ort.py b/onnxruntime/python/providers/tvm/ort.py deleted file mode 100644 index be6d23f39c532..0000000000000 --- a/onnxruntime/python/providers/tvm/ort.py +++ /dev/null @@ -1,140 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import collections -import copy -import logging -import os - -import onnx -import tvm -from tvm import auto_scheduler, autotvm, relay -from tvm.contrib import graph_executor -from tvm.relay import vm - -log = logging.getLogger("tvm_ep") - -ANSOR_TYPE = "Ansor" -AUTO_TVM_TYPE = "AutoTVM" - - -@tvm.register_func("tvm_onnx_import_and_compile") -def onnx_compile( - model_string, - model_path, - executor, - target, - target_host, - opt_level, - opset, - freeze_params, - input_shapes, - nhwc=False, - tuning_logfile="", - tuning_type=AUTO_TVM_TYPE, -): - def get_tvm_executor(irmod, executor, target, params): - if executor == "vm": - log.info("Build TVM virtual machine") - lib = vm.compile( - copy.deepcopy(irmod), - target, - params=params, - ) - elif executor == "graph": - log.info("Build TVM graph executor") - lib = relay.build(irmod, target=target, params=params) - else: - log.error(f'Executor type {executor} is unsupported. Only "vm" and "graph" types are supported') - return None - return lib - - model = onnx.load_model_from_string(bytes(model_string)) - if model_path: - base_dir = os.path.dirname(os.path.abspath(model_path)) - onnx.load_external_data_for_model(model, base_dir) - - # Collect only feed input names from all input names - all_input_names = [node.name for node in model.graph.input] - all_initializer = [node.name for node in model.graph.initializer] - net_feed_input_names = list(set(all_input_names) - set(all_initializer)) - - # Match names and input shapes - all_input_mapping = [(name, shape) for (name, shape) in zip(all_input_names, input_shapes)] - # Using an ordereddict maintains input ordering. 
- shape_dict = collections.OrderedDict(all_input_mapping) - # Get only feed input pairs - feed_shape_dict = {} - for name in net_feed_input_names: - feed_shape_dict[name] = shape_dict[name] - - irmod, params = relay.frontend.from_onnx(model, feed_shape_dict, opset=opset, freeze_params=freeze_params) - irmod = relay.transform.DynamicToStatic()(irmod) - - # Tuning file can be set by client through ep options - if not tuning_logfile: - tuning_logfile = os.getenv("AUTOTVM_TUNING_LOG") - lib = None - tvm_target = tvm.target.Target(target, host=target_host) - if tuning_logfile: - if tuning_type == ANSOR_TYPE: - desired_layouts = { - "nn.conv2d": ["NHWC", "default"], - "nn.conv2d_transpose": ["NHWC", "default"], - "nn.upsampling": ["NHWC", "default"], - "vision.roi_align": ["NHWC", "default"], - } - log.info("Use tuning file from %s: %s", ANSOR_TYPE, tuning_logfile) - with auto_scheduler.ApplyHistoryBest(tuning_logfile): # noqa: SIM117 - with tvm.transform.PassContext( - opt_level=opt_level, - config={ - "relay.backend.use_auto_scheduler": True, - "relay.FuseOps.max_depth": 30, - }, - ): - if nhwc: - seq = tvm.transform.Sequential( - [ - relay.transform.InferType(), - relay.transform.ConvertLayout(desired_layouts), - relay.transform.EliminateCommonSubexpr(), - relay.transform.FoldConstant(), - ] - ) - irmod = seq(irmod) - lib = get_tvm_executor(irmod, executor, tvm_target, params) - elif tuning_type == AUTO_TVM_TYPE: - with relay.build_config(opt_level=opt_level): - log.info("Use tuning file from %s: %s", AUTO_TVM_TYPE, tuning_logfile) - with autotvm.apply_history_best(tuning_logfile): - lib = get_tvm_executor(irmod, executor, tvm_target, params) - else: - log.error( - f"Tuning log type {tuning_type} is unsupported. " - f"Only {ANSOR_TYPE} and {AUTO_TVM_TYPE} types are supported" - ) - return None - else: - with tvm.transform.PassContext(opt_level=opt_level): - lib = get_tvm_executor(irmod, executor, tvm_target, params) - - if lib is None: - return None - - ctx = tvm.device(target, 0) - if executor == "vm": - m = tvm.runtime.vm.VirtualMachine(lib, ctx) - elif executor == "graph": - m = graph_executor.GraphModule(lib["default"](ctx)) - else: - print( - f"ERROR: Executor type {executor} is unsupported. 
", - 'Only "vm" and "graph" types are supported', - ) - return None - - return m.module diff --git a/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py b/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py index 1bc22eb0e5713..b7d32fd6b2353 100644 --- a/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py +++ b/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py @@ -20,135 +20,158 @@ def __init__(self): self.dim = [] -def is_quantized_data_type(qnn_data_type): - # QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_FIXED_POINT_16 - return qnn_data_type == 0x0408 or qnn_data_type == 0x0416 or qnn_data_type == 0x0308 or qnn_data_type == 0x0316 - - -def qnn_data_type_to_onnx_data_type(qnn_data_type): - # QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UINT_8 - if qnn_data_type == 0x0408 or qnn_data_type == 0x0108: - return TensorProto.UINT8 - # QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_UINT_16 - elif qnn_data_type == 0x0416 or qnn_data_type == 0x0116: - return TensorProto.UINT16 - # QNN_DATATYPE_UFIXED_POINT_32 QNN_DATATYPE_UINT_32 - elif qnn_data_type == 0x0432 or qnn_data_type == 0x0132: - return TensorProto.UINT32 - # QNN_DATATYPE_UINT_64 - elif qnn_data_type == 0x0164: - return TensorProto.UINT64 - # QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_INT_8 - elif qnn_data_type == 0x0308 or qnn_data_type == 0x0008: - return TensorProto.INT8 - # QNN_DATATYPE_FIXED_POINT_16 QNN_DATATYPE_INT_16 - elif qnn_data_type == 0x0316 or qnn_data_type == 0x0016: - return TensorProto.INT16 - # QNN_DATATYPE_FIXED_POINT_32 QNN_DATATYPE_INT_32 - elif qnn_data_type == 0x0332 or qnn_data_type == 0x0032: - return TensorProto.INT32 - # QNN_DATATYPE_INT_64 - elif qnn_data_type == 0x0064: - return TensorProto.INT64 - # QNN_DATATYPE_FLOAT_16 - elif qnn_data_type == 0x0216: - return TensorProto.FLOAT16 - # QNN_DATATYPE_FLOAT_32 - elif qnn_data_type == 0x0232: - return TensorProto.FLOAT - # QNN_DATATYPE_BOOL_8 - elif qnn_data_type == 0x0508: - return TensorProto.BOOL +def is_quantized_data_type(qnn_data_type, is_converter_json): + if is_converter_json: + # QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_FIXED_POINT_16 + return qnn_data_type == 0x0408 or qnn_data_type == 0x0416 or qnn_data_type == 0x0308 or qnn_data_type == 0x0316 else: - return TensorProto.UNDEFINED - - -def parse_qnn_json_file(qnn_json_file_path, qnn_input_tensor_dic, qnn_output_tensor_dic): - with open(qnn_json_file_path) as qnn_json_file: - qnn_json = json.load(qnn_json_file) - assert "graph" in qnn_json, "QNN converted json file not valid. Can't find graph." - assert "tensors" in qnn_json["graph"], "QNN converted json file not valid. Can't find tensors." - for qnn_tensor_name, qnn_tensor_attribute in qnn_json["graph"]["tensors"].items(): - # type:0 - QNN input tensor, type:1 - QNN output tensor - assert ( - "type" in qnn_tensor_attribute - and "data_type" in qnn_tensor_attribute - and "dims" in qnn_tensor_attribute - ), "QNN converted json file not valid. 
Can't find some keys from tensors" - - # Get all graph inputs - if qnn_tensor_attribute["type"] == 0: - qnn_tensor = QnnTensorStruct() - qnn_tensor.name = qnn_tensor_name - qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"]) - qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"]) - qnn_tensor.dim = qnn_tensor_attribute["dims"] - if ( - qnn_tensor_attribute["quant_params"]["definition"] == 1 - and qnn_tensor_attribute["quant_params"]["encoding"] == 0 - ): - qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"] - qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"] - qnn_input_tensor_dic[qnn_tensor_name] = qnn_tensor - - # Get all graph outputs - if qnn_tensor_attribute["type"] == 1: - qnn_tensor = QnnTensorStruct() - qnn_tensor.name = qnn_tensor_name - qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"]) - qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"]) - qnn_tensor.dim = qnn_tensor_attribute["dims"] - if ( - qnn_tensor_attribute["quant_params"]["definition"] == 1 - and qnn_tensor_attribute["quant_params"]["encoding"] == 0 - ): - qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"] - qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"] - qnn_output_tensor_dic[qnn_tensor_name] = qnn_tensor + return ( + qnn_data_type == "QNN_DATATYPE_UFIXED_POINT_8" + or qnn_data_type == "QNN_DATATYPE_UFIXED_POINT_16" + or qnn_data_type == "QNN_DATATYPE_FIXED_POINT_8" + or qnn_data_type == "QNN_DATATYPE_FIXED_POINT_16" + ) - assert ( - len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1 - ), "Converted QNN model not valid. It should have at least 1 input & 1 output." 
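The removed parser (and its replacement added later in this patch) converts QNN's scale_offset parameters to ONNX quantization parameters by negating the offset (`offset = 0 - offset`). A minimal sketch of that sign convention, assuming QNN dequantizes as scale * (q + offset) while ONNX DequantizeLinear uses scale * (q - zero_point); all values below are invented for illustration:

import numpy as np

def qnn_scale_offset_to_onnx(scale: float, qnn_offset: int) -> tuple:
    # ONNX zero_point is the negated QNN offset.
    return scale, -qnn_offset

# Hypothetical uint8 tensor parameters as they might appear in a converter JSON.
scale, zero_point = qnn_scale_offset_to_onnx(0.02, -128)

q = np.array([0, 128, 255], dtype=np.uint8).astype(np.int32)
real_onnx = scale * (q - zero_point)   # ONNX view: scale * (q - zero_point)
real_qnn = 0.02 * (q + (-128))         # QNN view: scale * (q + offset)
assert np.allclose(real_onnx, real_qnn)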
+def qnn_data_type_to_onnx_data_type(qnn_data_type, is_converter_json): + if is_converter_json: + # QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UINT_8 + if qnn_data_type == 0x0408 or qnn_data_type == 0x0108: + return TensorProto.UINT8 + # QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_UINT_16 + elif qnn_data_type == 0x0416 or qnn_data_type == 0x0116: + return TensorProto.UINT16 + # QNN_DATATYPE_UFIXED_POINT_32 QNN_DATATYPE_UINT_32 + elif qnn_data_type == 0x0432 or qnn_data_type == 0x0132: + return TensorProto.UINT32 + # QNN_DATATYPE_UINT_64 + elif qnn_data_type == 0x0164: + return TensorProto.UINT64 + # QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_INT_8 + elif qnn_data_type == 0x0308 or qnn_data_type == 0x0008: + return TensorProto.INT8 + # QNN_DATATYPE_FIXED_POINT_16 QNN_DATATYPE_INT_16 + elif qnn_data_type == 0x0316 or qnn_data_type == 0x0016: + return TensorProto.INT16 + # QNN_DATATYPE_FIXED_POINT_32 QNN_DATATYPE_INT_32 + elif qnn_data_type == 0x0332 or qnn_data_type == 0x0032: + return TensorProto.INT32 + # QNN_DATATYPE_INT_64 + elif qnn_data_type == 0x0064: + return TensorProto.INT64 + # QNN_DATATYPE_FLOAT_16 + elif qnn_data_type == 0x0216: + return TensorProto.FLOAT16 + # QNN_DATATYPE_FLOAT_32 + elif qnn_data_type == 0x0232: + return TensorProto.FLOAT + # QNN_DATATYPE_BOOL_8 + elif qnn_data_type == 0x0508: + return TensorProto.BOOL + else: + return TensorProto.UNDEFINED + else: + # QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UINT_8 + if qnn_data_type == "QNN_DATATYPE_UFIXED_POINT_8" or qnn_data_type == "QNN_DATATYPE_UINT_8": + return TensorProto.UINT8 + # QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_UINT_16 + elif qnn_data_type == "QNN_DATATYPE_UFIXED_POINT_16" or qnn_data_type == "QNN_DATATYPE_UINT_16": + return TensorProto.UINT16 + # QNN_DATATYPE_UFIXED_POINT_32 QNN_DATATYPE_UINT_32 + elif qnn_data_type == "QNN_DATATYPE_UFIXED_POINT_32" or qnn_data_type == "QNN_DATATYPE_UINT_32": + return TensorProto.UINT32 + # QNN_DATATYPE_UINT_64 + elif qnn_data_type == "QNN_DATATYPE_UINT_64": + return TensorProto.UINT64 + # QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_INT_8 + elif qnn_data_type == "QNN_DATATYPE_FIXED_POINT_8" or qnn_data_type == "QNN_DATATYPE_INT_8": + return TensorProto.INT8 + # QNN_DATATYPE_FIXED_POINT_16 QNN_DATATYPE_INT_16 + elif qnn_data_type == "QNN_DATATYPE_FIXED_POINT_16" or qnn_data_type == "QNN_DATATYPE_INT_16": + return TensorProto.INT16 + # QNN_DATATYPE_FIXED_POINT_32 QNN_DATATYPE_INT_32 + elif qnn_data_type == "QNN_DATATYPE_FIXED_POINT_32" or qnn_data_type == "QNN_DATATYPE_INT_32": + return TensorProto.INT32 + # QNN_DATATYPE_INT_64 + elif qnn_data_type == "QNN_DATATYPE_INT_64": + return TensorProto.INT64 + # QNN_DATATYPE_FLOAT_16 + elif qnn_data_type == "QNN_DATATYPE_FLOAT_16": + return TensorProto.FLOAT16 + # QNN_DATATYPE_FLOAT_32 + elif qnn_data_type == "QNN_DATATYPE_FLOAT_32": + return TensorProto.FLOAT + # QNN_DATATYPE_BOOL_8 + elif qnn_data_type == "QNN_DATATYPE_BOOL_8": + return TensorProto.BOOL + else: + return TensorProto.UNDEFINED -# Onnxruntime QNN EP can support context binary file generated by QNN tool chain. However QNN generated context binary file -# uses channel last data layout and 8 bits or 16 bits for input and output. -# This script gets the QNN model input & output information from QNN converted model_net.json file, compare them with Onnx model -# and inserts Cast, Transpose nodes to Onnx model if required -def main(): - parser = ArgumentParser("Generate Onnx model which includes the QNN context binary.") - parser.add_argument("-b", "--qnn_bin", help="Required. 
Path to Qnn context binary file.", required=True, type=str) - parser.add_argument( - "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str - ) - parser.add_argument( - "--disable_embed_mode", - action="store_true", - default=False, - help="Set embed_mode=1 which mean embed Qnn context binary into the onnx model. Otherwise, set context binary file path in the onnx model", - ) - args = parser.parse_args() - # Parse Qnn model_net.json file to get the graph input output information - qnn_input_tensor_dic = {} - qnn_output_tensor_dic = {} - parse_qnn_json_file(args.qnn_json, qnn_input_tensor_dic, qnn_output_tensor_dic) +def parse_qnn_converter_json_file(qnn_convert_json, qnn_input_tensor_dic, qnn_output_tensor_dic): + is_qnn_converter_json = True + for qnn_tensor_name, qnn_tensor_attribute in qnn_convert_json["graph"]["tensors"].items(): + # type:0 - QNN input tensor, type:1 - QNN output tensor + assert ( + "type" in qnn_tensor_attribute and "data_type" in qnn_tensor_attribute and "dims" in qnn_tensor_attribute + ), "QNN converted json file not valid. Can't find some keys from tensors" - if args.disable_embed_mode: - ep_cache_context_content = args.qnn_bin - ctx_embed_mode = 0 - else: - with open(args.qnn_bin, "rb") as file: - ep_cache_context_content = file.read() - ctx_embed_mode = 1 + # Get all graph inputs + if qnn_tensor_attribute["type"] == 0: + qnn_tensor = QnnTensorStruct() + qnn_tensor.name = qnn_tensor_name + qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type( + qnn_tensor_attribute["data_type"], is_qnn_converter_json + ) + qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json) + qnn_tensor.dim = qnn_tensor_attribute["dims"] + if ( + qnn_tensor_attribute["quant_params"]["definition"] == 1 + and qnn_tensor_attribute["quant_params"]["encoding"] == 0 + ): + qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"] + qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"] + qnn_input_tensor_dic[qnn_tensor_name] = qnn_tensor + + # Get all graph outputs + if qnn_tensor_attribute["type"] == 1: + qnn_tensor = QnnTensorStruct() + qnn_tensor.name = qnn_tensor_name + qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type( + qnn_tensor_attribute["data_type"], is_qnn_converter_json + ) + qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json) + qnn_tensor.dim = qnn_tensor_attribute["dims"] + if ( + qnn_tensor_attribute["quant_params"]["definition"] == 1 + and qnn_tensor_attribute["quant_params"]["encoding"] == 0 + ): + qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"] + qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"] + qnn_output_tensor_dic[qnn_tensor_name] = qnn_tensor + assert ( + len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1 + ), "Converted QNN model not valid. It should have at least 1 input & 1 output." 
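The reworked helpers now accept either the numeric QNN type codes found in converter-produced model_net.json files or the QNN_DATATYPE_* string names found in JSON extracted from a context binary, selected by the is_converter_json flag. A table-driven sketch of the same idea (only a few representative entries are shown; the helper name is hypothetical):

from onnx import TensorProto

# Numeric codes appear in converter JSON; string names appear in context-binary JSON.
_CODE_TO_ONNX = {0x0408: TensorProto.UINT8, 0x0108: TensorProto.UINT8,
                 0x0416: TensorProto.UINT16, 0x0232: TensorProto.FLOAT}
_NAME_TO_ONNX = {"QNN_DATATYPE_UFIXED_POINT_8": TensorProto.UINT8,
                 "QNN_DATATYPE_UINT_8": TensorProto.UINT8,
                 "QNN_DATATYPE_FLOAT_32": TensorProto.FLOAT}

def to_onnx_dtype(qnn_data_type, is_converter_json: bool) -> int:
    table = _CODE_TO_ONNX if is_converter_json else _NAME_TO_ONNX
    return table.get(qnn_data_type, TensorProto.UNDEFINED)

assert to_onnx_dtype(0x0408, True) == to_onnx_dtype("QNN_DATATYPE_UFIXED_POINT_8", False)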
+ + +def generate_wrapper_onnx_file( + grap_name, + model_file_name, + qnn_input_tensor_dic, + qnn_output_tensor_dic, + disable_embed_mode, + qnn_ctx_file, + quantized_IO, + qnn_sdk_version="unknown", +): graph_nodes = [] ini_list = [] value_infos = [] model_inputs = [] for qnn_input in qnn_input_tensor_dic.values(): - if qnn_input.is_quantized: + if qnn_input.is_quantized and not quantized_IO: q_scale_input_name = qnn_input.name + "_scale" q_offset_input_name = qnn_input.name + "_zp" q_scale = helper.make_tensor(q_scale_input_name, TensorProto.FLOAT, [], [qnn_input.scale]) @@ -170,13 +193,22 @@ def main(): else: model_inputs.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim)) + if disable_embed_mode: + ep_cache_context_content = qnn_ctx_file + ctx_embed_mode = 0 + else: + with open(qnn_ctx_file, "rb") as file: + ep_cache_context_content = file.read() + ctx_embed_mode = 1 + qnn_ep_context_node = helper.make_node( "EPContext", - name="QnnContext", + name=grap_name, inputs=qnn_input_tensor_dic.keys(), outputs=qnn_output_tensor_dic.keys(), ep_cache_context=ep_cache_context_content, embed_mode=ctx_embed_mode, + ep_sdk_version=qnn_sdk_version, source="Qnn", domain="com.microsoft", ) @@ -184,7 +216,7 @@ def main(): model_outputs = [] for qnn_output in qnn_output_tensor_dic.values(): - if qnn_output.is_quantized: + if qnn_output.is_quantized and not quantized_IO: dq_scale_input_name = qnn_output.name + "_scale" dq_offset_input_name = qnn_output.name + "_zp" dq_scale = helper.make_tensor(dq_scale_input_name, TensorProto.FLOAT, [], [qnn_output.scale]) @@ -214,7 +246,120 @@ def main(): model_def = helper.make_model(graph_def, producer_name="MS") - onnx.save(model_def, args.qnn_json.replace(".json", "_qnn_ctx.onnx")) + onnx.save(model_def, model_file_name) + + +# parse Qnn graph from the json file that extracted from context binary file +def parse_qnn_graph(qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic): + is_qnn_converter_json = False + graph_name = qnn_graph["info"]["graphName"] + raw_inputs = qnn_graph["info"]["graphInputs"] + raw_outputs = qnn_graph["info"]["graphOutputs"] + + for raw_input in raw_inputs: + tensor_info = raw_input["info"] + qnn_tensor = QnnTensorStruct() + qnn_tensor.name = tensor_info["name"] + qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(tensor_info["dataType"], is_qnn_converter_json) + qnn_tensor.is_quantized = is_quantized_data_type(tensor_info["dataType"], is_qnn_converter_json) + qnn_tensor.dim = tensor_info["dimensions"] + if ( + tensor_info["quantizeParams"]["definition"] == "QNN_DEFINITION_DEFINED" + and tensor_info["quantizeParams"]["quantizationEncoding"] == "QNN_QUANTIZATION_ENCODING_SCALE_OFFSET" + ): + qnn_tensor.scale = tensor_info["quantizeParams"]["scaleOffset"]["scale"] + qnn_tensor.offset = 0 - tensor_info["quantizeParams"]["scaleOffset"]["offset"] + qnn_input_tensor_dic[qnn_tensor.name] = qnn_tensor + + for raw_output in raw_outputs: + tensor_info = raw_output["info"] + qnn_tensor = QnnTensorStruct() + qnn_tensor.name = tensor_info["name"] + qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(tensor_info["dataType"], is_qnn_converter_json) + qnn_tensor.is_quantized = is_quantized_data_type(tensor_info["dataType"], is_qnn_converter_json) + qnn_tensor.dim = tensor_info["dimensions"] + if ( + tensor_info["quantizeParams"]["definition"] == "QNN_DEFINITION_DEFINED" + and tensor_info["quantizeParams"]["quantizationEncoding"] == "QNN_QUANTIZATION_ENCODING_SCALE_OFFSET" + ): + qnn_tensor.scale 
= tensor_info["quantizeParams"]["scaleOffset"]["scale"] + qnn_tensor.offset = 0 - tensor_info["quantizeParams"]["scaleOffset"]["offset"] + qnn_output_tensor_dic[qnn_tensor.name] = qnn_tensor + + assert ( + len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1 + ), "Converted QNN model not valid. It should have at least 1 input & 1 output." + + return graph_name + + +# Onnxruntime QNN EP can support context binary file generated by QNN tool chain. However QNN generated context binary file +# uses channel last data layout and 8 bits or 16 bits for input and output. +# This script gets the QNN model input & output information from QNN converted model_net.json file, compare them with Onnx model +# and inserts Cast, Transpose nodes to Onnx model if required +def main(): + parser = ArgumentParser("Generate Onnx model which includes the QNN context binary.") + parser.add_argument("-b", "--qnn_bin", help="Required. Path to Qnn context binary file.", required=True, type=str) + parser.add_argument( + "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str + ) + parser.add_argument( + "--disable_embed_mode", + action="store_true", + default=False, + help="Set embed_mode=1 which mean embed Qnn context binary into the onnx model. Otherwise, set context binary file path in the onnx model", + ) + parser.add_argument( + "--quantized_IO", + action="store_true", + default=False, + help="QNN converted context binary use quantized data as graph inputs and outputs. Will keep it if quantized_IO=True, otherwise, will insert Q and DQ nodes accordingly to make the graph inputs & outputs as float32 data type.", + ) + args = parser.parse_args() + + # Parse Qnn model_net.json file to get the graph input output information + + with open(args.qnn_json) as qnn_json_file: + qnn_json_obj = json.load(qnn_json_file) + if "graph" in qnn_json_obj and "tensors" in qnn_json_obj["graph"]: + print("This json file is from Qnn converter") + qnn_input_tensor_dic = {} + qnn_output_tensor_dic = {} + parse_qnn_converter_json_file(qnn_json_obj, qnn_input_tensor_dic, qnn_output_tensor_dic) + + generate_wrapper_onnx_file( + "QnnContext", + args.qnn_json.replace(".json", "_qnn_ctx.onnx"), + qnn_input_tensor_dic, + qnn_output_tensor_dic, + args.disable_embed_mode, + args.qnn_bin, + args.quantized_IO, + ) + elif "info" in qnn_json_obj and "graphs" in qnn_json_obj["info"]: + print("This json file is extracted from QNN context binary file") + qnn_version = qnn_json_obj["info"]["buildId"] + for qnn_graph in qnn_json_obj["info"]["graphs"]: + qnn_input_tensor_dic = {} + qnn_output_tensor_dic = {} + graph_name = parse_qnn_graph(qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic) + + ctx_file_name = graph_name + "_qnn_ctx.onnx" + if not args.quantized_IO: + ctx_file_name = ctx_file_name.replace(".onnx", "_fp32_io.onnx") + + generate_wrapper_onnx_file( + graph_name, + ctx_file_name, + qnn_input_tensor_dic, + qnn_output_tensor_dic, + args.disable_embed_mode, + args.qnn_bin, + args.quantized_IO, + qnn_version, + ) + else: + print("json file unrecoginized.") if __name__ == "__main__": diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 9d397499d45a4..712e15a6a1ca9 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -10,6 +10,7 @@ from .quantize import DynamicQuantConfig # noqa: F401 from .quantize import QuantizationMode # noqa: F401 from .quantize 
import StaticQuantConfig # noqa: F401 +from .quantize import get_qdq_config # noqa: F401 from .quantize import quantize # noqa: F401 from .quantize import quantize_dynamic # noqa: F401 from .quantize import quantize_static # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index b20af5137d206..6235db3234d49 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -19,9 +19,10 @@ from .calibrate import TensorData from .onnx_model import ONNXModel from .quant_utils import ( + DEQUANT_OP_NAME, ONNX_TYPE_TO_NP_TYPE, + QUANT_OP_NAME, TENSOR_NAME_QUANT_SUFFIX, - QuantType, find_by_name, model_has_infer_metadata, normalize_axis, @@ -40,18 +41,26 @@ def __init__(self, **data: Dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if not isinstance(v, (int, str, np.ndarray)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") + if k == "axis" and not isinstance(v, int) and v is not None: + raise TypeError(f"Axis value must be an int or None, not {type(v)}.") if k == "scale" and v.dtype not in (np.float32, np.float16): raise ValueError(f"scale must a float32 or float16 numpy element but is {v.dtype} for k={k!r}") self.data[k] = v + def get(self, key, default_value=None): + return self.data.get(key, default_value) + def __iter__(self): yield from self.data def __getitem__(self, key): return self.data[key] + def __setitem__(self, key, value): + self.data[key] = value + def __len__(self): return len(self.data) @@ -88,9 +97,10 @@ def __init__( self.force_quantize_no_input_check = ( "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"] ) - self.is_weight_symmetric = self.extra_options.get( - "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) - ) + + # If user does not explicitly set "WeightSymmetric", then the weight's quantization type determines + # the symmetry (i.e., signed integer types will use symmetric quantization). See `def is_weight_symmetric()` + self._is_weight_symmetric: bool | None = self.extra_options.get("WeightSymmetric", None) self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False) self.min_real_range = self.extra_options.get("MinimumRealRange") @@ -131,6 +141,16 @@ def __init__( self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types() + def is_weight_symmetric(self, weight_quant_type: onnx.TensorProto.DataType) -> bool: + if self._is_weight_symmetric is not None: + return self._is_weight_symmetric # Return value explicitly set by user. 
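The comment above explains the new default: when extra_options["WeightSymmetric"] is not set, symmetry is inferred from the weight's quantization type. A condensed, standalone sketch of that rule, assuming onnx >= 1.16 for the INT4 enum:

from __future__ import annotations

import onnx

SIGNED_WEIGHT_TYPES = (
    onnx.TensorProto.INT4,
    onnx.TensorProto.INT8,
    onnx.TensorProto.INT16,
    onnx.TensorProto.FLOAT8E4M3FN,
)

def default_weight_symmetry(weight_quant_type: int, user_setting: bool | None = None) -> bool:
    # An explicit WeightSymmetric setting always wins; otherwise signed (and float8)
    # weight types default to symmetric quantization.
    if user_setting is not None:
        return user_setting
    return weight_quant_type in SIGNED_WEIGHT_TYPES

assert default_weight_symmetry(onnx.TensorProto.INT8) is True
assert default_weight_symmetry(onnx.TensorProto.UINT8) is False
assert default_weight_symmetry(onnx.TensorProto.UINT8, user_setting=True) is True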
+ return weight_quant_type in ( + onnx.TensorProto.INT4, + onnx.TensorProto.INT8, + onnx.TensorProto.INT16, + onnx.TensorProto.FLOAT8E4M3FN, + ) + def quantize_model(self): raise NotImplementedError @@ -160,6 +180,9 @@ def should_quantize_node(self, node): if node.op_type not in self.op_types_to_quantize: return False + if node.op_type in (DEQUANT_OP_NAME, QUANT_OP_NAME): + return False + if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude: return False @@ -230,9 +253,19 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1 # TODO: This formula should be explained including why the scale is not estimated for the bias as well. bias_scale = input_scale * weight_scale * beta - quantized_data = (np.asarray(bias_data) / bias_scale).round() - quantized_data = np.clip(quantized_data, np.iinfo(np.int32).min, np.iinfo(np.int32).max) - quantized_data = quantized_data.astype(np.int32) + # Quantize by dividing by bias_scale + quantized_data = np.asarray(bias_data, dtype=np.float64) / np.asarray(bias_scale, dtype=np.float64) + quantized_data = quantized_data.round() + + # Clip quantized data to the range of a int32 + int32_min = np.float64(np.iinfo(np.int32).min) + int32_max = np.float64(np.iinfo(np.int32).max) + if np.any(quantized_data < int32_min) or np.any(quantized_data > int32_max): + logging.warning( + f"Quantized bias `{bias_name}` exceeds the range of a int32. The bias scale is too small." + ) + + quantized_data = np.clip(quantized_data, int32_min, int32_max).astype(np.int32) # update bias initializer bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims) @@ -282,6 +315,7 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa If keep_float_weight is False, quantize the weight, or don't quantize the weight. :return: quantized weight name, zero point name, scale name """ + # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there. q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight.name + "_zero_point" scale_name = weight.name + "_scale" @@ -303,10 +337,11 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}" else: - _, _, zero_point, scale, q_weight_data = quantize_data( + symmetric = self.is_weight_symmetric(qType) if qType == self.weight_qType else self.is_activation_symmetric + zero_point, scale, q_weight_data = quantize_data( weight_data.flatten(), qType, - quant_overrides.get("symmetric", self.is_weight_symmetric), + quant_overrides.get("symmetric", symmetric), reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range), min_real_range=self.min_real_range, rmin_override=quant_overrides.get("rmin"), @@ -371,6 +406,7 @@ def quantize_weight_per_channel_impl( reduce_range=True, keep_float_weight=False, ): + # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there. 
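The reworked bias path divides in float64, rounds, and only then clips to int32, warning when the candidate scale is too small to represent the bias. A minimal numeric sketch of that computation (input values are invented):

import numpy as np

def quantize_bias_int32(bias: np.ndarray, input_scale: float, weight_scale: float, beta: float = 1.0):
    # bias_scale = input_scale * weight_scale * beta, computed in float64 to avoid precision loss.
    bias_scale = np.float64(input_scale) * np.float64(weight_scale) * beta
    q = np.round(np.asarray(bias, dtype=np.float64) / bias_scale)
    info = np.iinfo(np.int32)
    if np.any(q < info.min) or np.any(q > info.max):
        print("warning: quantized bias exceeds the int32 range; the bias scale is too small")
    return np.clip(q, info.min, info.max).astype(np.int32), bias_scale

q_bias, scale = quantize_bias_int32(np.array([0.5, -1.25]), input_scale=0.01, weight_scale=0.005)
# q_bias == [10000, -25000], scale == 5e-05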
initializer = find_by_name(weight_name, self.model.initializer()) if initializer is None: raise ValueError("{} is not an initializer", weight_name) @@ -409,13 +445,7 @@ def quantize_weight_per_channel_impl( if "quant_type" in quant_overrides_for_channels[0]: weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type # noqa: N806 - symmetric = quant_overrides_for_channels[0].get( - "symmetric", - ( - self.is_weight_symmetric - or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4) - ), - ) + symmetric = quant_overrides_for_channels[0].get("symmetric", self.is_weight_symmetric(weight_qType)) reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range) zero_point_list = [] scale_list = [] @@ -444,7 +474,7 @@ def quantize_weight_per_channel_impl( ), f"Unexpected type {type(quantized_per_channel_data)}" else: - _, _, zero_point, scale, quantized_per_channel_data = quantize_data( + zero_point, scale, quantized_per_channel_data = quantize_data( per_channel_data.flatten(), weight_qType, symmetric, @@ -529,4 +559,6 @@ def adjust_tensor_ranges(self): self.tensors_range[node.input[0]] = td # Adjust Softmax to range from 0.0 to 1.0 elif node.op_type == "Softmax": + if not self.should_quantize_node(node): + continue self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0)) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 16ad36c48cc74..1d91141a117ad 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -225,6 +225,252 @@ def __init__( self.accuracy_level = accuracy_level +class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + tokenizer_dir, + dataset_name="cnn", + cache_dir="./cache", + calibration_method="awq_lite", + ): + """ + Configuration for the nvidia_awq quantization method. + + Args: + tokenizer_dir (str): pathof the tokenizer dir. + dataset_name (str): Name of the dataset. + cache_dir (str): Directory for caching. + calibration_method (str): calib method for nvidia_awq. + """ + # Import torch and DataLoader + try: + import torch + from torch.utils.data import DataLoader + + self.torch = torch + self.DataLoader = DataLoader + except ImportError: + print( + "Error: The 'torch' library is required but not installed. Please install it using 'pip install torch'." + ) + raise ImportError("torch is not installed. Exiting.") from None + + # Import datasets + try: + from datasets import load_dataset + + self.load_dataset = load_dataset + except ImportError: + print( + "Error: The 'datasets' library is required but not installed. Please install it using 'pip install datasets'." + ) + raise ImportError("datasets is not installed. Exiting.") from None + + # Import transformers + try: + from transformers import AutoConfig, AutoTokenizer + + self.AutoConfig = AutoConfig + self.AutoTokenizer = AutoTokenizer + except ImportError: + print( + "Error: The 'transformers' library is required but not installed. Please install it using 'pip install transformers'." + ) + raise ImportError("transformers is not installed. 
Exiting.") from None + + super().__init__( + algorithm="nvidia_awq", + quant_format=QuantFormat.QDQ, + op_types_to_quantize=None, # Assuming op_types_to_quantize is handled elsewhere + quant_axes=None, # Assuming quant_axes is handled elsewhere + ) + + # Determine the device + device = self.torch.device("cuda" if self.torch.cuda.is_available() else "cpu") + + calib_inputs = self.get_calib_inputs( + dataset_name=dataset_name, + model_name=tokenizer_dir, + cache_dir=cache_dir, + calib_size=32, + batch_size=1, + block_size=512, + device=device, + use_fp16=True, + use_buffer_share=False, + add_past_kv_inputs=True, + max_calib_rows_to_load=128, + add_position_ids=True, + ) + + self.calibration_data_reader = calib_inputs + self.calibration_method = calibration_method + + def make_model_input( + self, + config, + input_ids_arg, + attention_mask_arg, + add_past_kv_inputs, + device, + use_fp16, + use_buffer_share, + add_position_ids, + ): + # Access torch from the instance variable + torch = self.torch + + input_ids = input_ids_arg + attention_mask = attention_mask_arg + + if isinstance(input_ids_arg, list): + input_ids = torch.tensor(input_ids_arg, device=device, dtype=torch.int64) + attention_mask = torch.tensor(attention_mask_arg, device=device, dtype=torch.int64) + + inputs = { + "input_ids": input_ids.contiguous(), + "attention_mask": attention_mask.contiguous(), + } + + if add_position_ids: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + inputs["position_ids"] = position_ids.contiguous() + + if add_past_kv_inputs: + torch_dtype = torch.float16 if use_fp16 else torch.float32 + batch_size, sequence_length = input_ids.shape + max_sequence_length = config.max_position_embeddings + num_heads, head_size = ( + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + ) + for i in range(config.num_hidden_layers): + past_key = torch.zeros( + batch_size, + num_heads, + max_sequence_length if use_buffer_share else 0, + head_size, + device=device, + dtype=torch_dtype, + ) + past_value = torch.zeros( + batch_size, + num_heads, + max_sequence_length if use_buffer_share else 0, + head_size, + device=device, + dtype=torch_dtype, + ) + inputs.update( + { + f"past_key_values.{i}.key": past_key.contiguous(), + f"past_key_values.{i}.value": past_value.contiguous(), + } + ) + + return inputs + + def get_calib_inputs( + self, + dataset_name, + model_name, + cache_dir, + calib_size, + batch_size, + block_size, + device, + use_fp16, + use_buffer_share, + add_past_kv_inputs, + max_calib_rows_to_load, + add_position_ids, + ): + # Access transformers and datasets from the instance variables + auto_config = self.AutoConfig + auto_tokenizer = self.AutoTokenizer + load_dataset = self.load_dataset + + config = auto_config.from_pretrained( + model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True + ) + tokenizer = auto_tokenizer.from_pretrained( + model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + tokenizer.pad_token = tokenizer.eos_token + + assert calib_size <= max_calib_rows_to_load, "calib size should be no more than max_calib_rows_to_load" + + if "cnn" in dataset_name: + dataset2 = load_dataset("cnn_dailymail", name="3.0.0", split="train").select(range(max_calib_rows_to_load)) + column = "article" + elif "pile" in dataset_name: + dataset2 = load_dataset("mit-han-lab/pile-val-backup", split="validation") + column = "text" + else: + 
raise ValueError(f'dataset "{dataset_name}" not supported') + + dataset2 = dataset2[column][:calib_size] + batch_encoded = tokenizer.batch_encode_plus( + dataset2, return_tensors="pt", padding=True, truncation=True, max_length=block_size + ) + batch_encoded = batch_encoded.to(device) + batch_encoded_input_ids = batch_encoded["input_ids"] + batch_encoded_attention_mask = batch_encoded["attention_mask"] + + # Access DataLoader from the instance variable + data_loader = self.DataLoader + + calib_dataloader_input_ids = data_loader(batch_encoded_input_ids, batch_size=batch_size, shuffle=False) + calib_dataloader_attention_mask = data_loader( + batch_encoded_attention_mask, batch_size=batch_size, shuffle=False + ) + + assert len(calib_dataloader_input_ids.dataset) == len(calib_dataloader_attention_mask.dataset) + assert len(calib_dataloader_input_ids) == len(calib_dataloader_attention_mask) + + number_of_batched_samples = calib_size // batch_size + + batched_input_ids = [] + for idx, data in enumerate(calib_dataloader_input_ids): + batched_input_ids.append(data) + if idx == (number_of_batched_samples - 1): + break + + batched_attention_mask = [] + for idx, data in enumerate(calib_dataloader_attention_mask): + batched_attention_mask.append(data) + if idx == (number_of_batched_samples - 1): + break + + print( + f"\n--Quantize-Script-- number_of_batched_samples={number_of_batched_samples}, " + f"batch-input-ids-list-len={len(batched_input_ids)}, batched_attention_mask={len(batched_attention_mask)}\n" + ) + + batched_inputs_list = [] + for i in range(number_of_batched_samples): + input_ids = batched_input_ids[i] + attention_mask = batched_attention_mask[i] + + inputs = self.make_model_input( + config, + input_ids, + attention_mask, + add_past_kv_inputs, + device, + use_fp16, + use_buffer_share, + add_position_ids, + ) + inputs = {input_name: torch_tensor.cpu().numpy() for input_name, torch_tensor in inputs.items()} + batched_inputs_list.append(inputs) + + print(f"\n--Quantize-Script-- number of batched inputs = {len(batched_inputs_list)}\n") + return batched_inputs_list + + def is_divisible(val1, val2): return int(val2 * np.ceil(val1 / val2)) == val1 @@ -777,6 +1023,49 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP return results +class NVAWQWeightOnlyQuantizer: + def __init__( + self, + config: NVAWQWeightOnlyQuantConfig, + ): + self.config = config + + def quantize_awq(self, model: ModelProto | str) -> ModelProto: + """ + Perform nvidia_awq quantization using ModelOpt's int4 quantize function. + + Args: + model (ModelProto): The ONNX model to quantize. + + Returns: + ModelProto: The quantized ONNX model. + """ + try: + from modelopt.onnx.quantization.int4 import quantize as quantize_int4 + except ImportError: + print( + "Please ensure that the 'modelopt' package is installed. Please install it using pip install nvidia_modelopt." + ) + raise ImportError( + "modelopt is not installed. Please install it using pip install nvidia_modelopt. Exiting." 
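Putting the two new pieces together, the AWQ path is driven by a config object that builds its own calibration batches and a thin quantizer that defers to ModelOpt. A hedged usage sketch, assuming the module is importable as onnxruntime.quantization.matmul_4bits_quantizer (as for the existing configs) and that torch, datasets, transformers, and nvidia_modelopt are installed; all paths are placeholders:

from onnxruntime.quantization.matmul_4bits_quantizer import (
    NVAWQWeightOnlyQuantConfig,
    NVAWQWeightOnlyQuantizer,
)

# The tokenizer dir is used to build calibration inputs from the chosen dataset.
config = NVAWQWeightOnlyQuantConfig(
    tokenizer_dir="path/to/hf_model_or_tokenizer",   # placeholder
    dataset_name="cnn",
    cache_dir="./cache",
    calibration_method="awq_lite",                    # or "awq_clip"
)

quantized_model = NVAWQWeightOnlyQuantizer(config).quantize_awq("path/to/model.onnx")

The same flow is also wired into MatMul4BitsQuantizer later in this patch, so passing this config as algo_config should route process() through quantize_awq.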
+ ) from None + + logger.info("Starting nvidia_awq quantization...") + + # Prepare calibration inputs + calib_inputs = self.config.calibration_data_reader + + # Perform quantization using ModelOpt's int4 quantize function + quantized_model = quantize_int4( + model, + calibration_method=self.config.calibration_method, + calibration_data_reader=calib_inputs, + ) + + logger.info("Completed nvidia_awq quantization.") + return quantized_model + + # TODO(fajin): change class name class MatMul4BitsQuantizer: """ @@ -821,6 +1110,7 @@ def __init__( self.nodes_to_exclude = set(nodes_to_exclude) self.nodes_to_include = set(nodes_to_include) if nodes_to_include else None self.node_quantizer = None + if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( block_size=block_size, @@ -835,6 +1125,8 @@ def __init__( self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) elif algo_config.algorithm == "DEFAULT": self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config) + elif algo_config.algorithm == "nvidia_awq": + self.node_quantizer = NVAWQWeightOnlyQuantizer(self.algo_config) def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] @@ -966,6 +1258,16 @@ def process(self): self._process_subgraph(graph_stack) self.model.clean_initializers() + elif self.algo_config.algorithm == "nvidia_awq": + + # Handle nvidia_awq quantization + logger.info("Processing nvidia_awq quantization...") + self.model = self.node_quantizer.quantize_awq( + self.model.model if self.model_path is None else self.model_path + ) + logger.info("Completed nvidia_awq quantization.") + self.model = ONNXModel(self.model) # Ensure the model is wrapped back into ONNXModel + self.model.clean_initializers() else: # use Intel® Neural Compressor for RTN or GPTQ weight-only quantize algorithm try: @@ -1012,7 +1314,7 @@ def parse_args(): "--quant_method", default="default", type=str, - choices=["default", "hqq", "rtn", "gptq"], + choices=["default", "hqq", "rtn", "gptq", "nvidia_awq"], help="the algorithm used to quantize weight, \nrtn and gptq leverage Intel® Neural Compressor", ) parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") @@ -1076,7 +1378,33 @@ def parse_args(): "Specify the axis to quantize for an op. Default {MatMul:0, Gather:1}" "Example: --quant_axes MatMul:0 Gather:1", ) - + # Group arguments specific to nvidia_awq + nv_awq_config = parser.add_argument_group("nvidia_awq", "Arguments specific to nvidia_awq quantization") + nv_awq_config.add_argument( + "--calib_dataset_name", + type=str, + default="cnn", + help="Name of the calibration dataset for nvidia_awq.", + ) + nv_awq_config.add_argument( + "--tokenizer_dir", + type=str, + required=False, + help="Path of the tokenizer dir.", + ) + nv_awq_config.add_argument( + "--calibration_method", + type=str, + required=False, + choices=["awq", "awq_clip"], + help="Support two options, awq implementation and weight clipping.", + ) + nv_awq_config.add_argument( + "--cache_dir", + type=str, + default="./cache", + help="Cache directory for calibration data.", + ) return parser.parse_args() @@ -1117,6 +1445,27 @@ def parse_args(): quant_config = RTNWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize) elif args.quant_method == "gptq": quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize) + elif args.quant_method == "nvidia_awq": + + if quant_format == QuantFormat.QOperator: + logger.warning("QOperator is not applicable to nvidia_awq. 
overriding the value to QDQ") + quant_format = QuantFormat.QDQ + + model = input_model_path + if args.calibration_method is not None: + if args.calibration_method == "awq": + calibration_method = "awq_lite" + else: + calibration_method = "awq_clip" + else: + calibration_method = "awq_lite" + + quant_config = NVAWQWeightOnlyQuantConfig( + dataset_name=args.calib_dataset_name, + tokenizer_dir=args.tokenizer_dir, + cache_dir=args.cache_dir, + calibration_method=calibration_method, + ) else: raise ValueError(f"Unsupported quantization method: {args.quant_method}") diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 174bf5fd1509c..43105550139de 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -296,6 +296,26 @@ def get_largest_node_name_suffix(self, node_name_prefix): return suffix + def get_largest_initializer_name_suffix(self, initializer_name_prefix): + """ + Gets the largest initializer name integer suffix for all initializer names that begin + with `initializer_name_prefix`. This can be used to create unique initializer names. + + Example: for initializer names 'my_weight_0' and 'my_weight_3', this method returns 3 if + `initializer_name_prefix` is 'my_weight_'. + """ + suffix = -1 + + for initializer in self.model.graph.initializer: + if initializer.name.startswith(initializer_name_prefix): + try: + index = int(initializer.name[len(initializer_name_prefix) :]) + suffix = max(index, suffix) + except ValueError: + continue + + return suffix + def find_nodes_by_initializer(self, graph, initializer): """ Find all nodes with given initializer as an input. diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index e1e4a4f724fdc..424f9b7e180a3 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -942,7 +942,7 @@ def _dequantize_value(self, value_name): self.model.model.producer_name == "onnx-quantizer" and scale_init is not None ): # axis is not specified so scale_init must be a scalar. - assert onnx.numpy_helper.to_array(scale_init).size == 1 + assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) diff --git a/onnxruntime/python/tools/quantization/operators/pad.py b/onnxruntime/python/tools/quantization/operators/pad.py index 5f3c1231e62d6..b3e9ddb5e6278 100644 --- a/onnxruntime/python/tools/quantization/operators/pad.py +++ b/onnxruntime/python/tools/quantization/operators/pad.py @@ -1,3 +1,12 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
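The new ONNXModel helper mirrors the existing node-name variant: it scans initializer names with a given prefix and returns the largest integer suffix so that callers (such as the bias-duplication logic later in this patch) can mint unique names. A standalone sketch of the same scan, with hypothetical names:

from __future__ import annotations

def largest_suffix(names: list[str], prefix: str) -> int:
    # Returns -1 when no name of the form f"{prefix}{integer}" exists.
    suffix = -1
    for name in names:
        if name.startswith(prefix):
            try:
                suffix = max(suffix, int(name[len(prefix):]))
            except ValueError:
                continue
    return suffix

assert largest_suffix(["my_weight_0", "my_weight_3", "my_weight_bias"], "my_weight_") == 3
next_unique = f"my_weight_{largest_suffix(['my_weight_0', 'my_weight_3'], 'my_weight_') + 1}"
assert next_unique == "my_weight_4"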
+# -------------------------------------------------------------------------- +from __future__ import annotations + +from typing import Any + +import numpy as np import onnx from ..quant_utils import ( @@ -8,6 +17,7 @@ quantize_nparray, ) from .base_operator import QuantOperatorBase +from .qdq_base_operator import QDQOperatorBase class QPad(QuantOperatorBase): @@ -98,3 +108,65 @@ def quantize(self): node.input[0] = quantized_input_value.q_name node.output[0] = quantized_output_value.q_name self.quantizer.new_nodes += [node] + + +class QDQPad(QDQOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def _get_pad_const_val(self, attrs_dict: dict[str, Any]) -> np.ndarray | None: + """ + Returns the Pad's constant padding value. Returns `None` if the padding value is + not constant (i.e., comes from a dynamic input). + """ + const_val = None + onnx_tensor_type = self.quantizer.model.get_tensor_type(self.node.input[0]) + if onnx_tensor_type is None: + return None + + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type.elem_type) + if self.quantizer.opset_version < 11: + const_val = np.array(attrs_dict.get("value", 0), dtype=np_dtype) + elif len(self.node.input) >= 3 and self.node.input[2]: + const_val = self.quantizer.model.get_constant_value(self.node.input[2]) + else: + const_val = np.array(0, dtype=np_dtype) + + return const_val + + def _should_quantize_output_same_as_input(self) -> bool: + """ + Returns true if Pad's output should use the same quantization parameters as input[0] + """ + attrs_dict = {} + for attribute in self.node.attribute: + kv = attribute_to_kwarg(attribute) + attrs_dict.update(kv) + + pad_mode = attrs_dict.get("mode", b"constant") + if pad_mode in (b"reflect", b"edge", b"wrap"): + # These modes pad the output with a value that already exists in the input. + # So, we can quantize the output the same as the input. + return True + + # For 'constant' mode, if padding with 0, we can also quantize the output the same as the input + # because our quantization floating-point range always includes 0. 
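The comments above spell out the rule QDQPad applies: modes that only replicate existing input values (reflect, edge, wrap) and constant padding with 0 let the output reuse the input's quantization parameters; any other constant value forces separate output parameters. A compact sketch of just that decision, under the simplifying assumption that the padding constant is already available as a value (the real class reads it from the node's attributes or constant input):

import numpy as np

def pad_output_shares_input_qparams(mode: bytes, const_val=None) -> bool:
    if mode in (b"reflect", b"edge", b"wrap"):
        return True  # output values already exist in the input
    if mode == b"constant" and const_val is not None:
        val = np.asarray(const_val)
        if val.dtype in (np.float32, np.float16):
            return float(val.item()) == 0.0  # 0 is always inside the calibrated range
    return False

assert pad_output_shares_input_qparams(b"edge")
assert pad_output_shares_input_qparams(b"constant", np.float32(0.0))
assert not pad_output_shares_input_qparams(b"constant", np.float32(-1.0))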
+ if pad_mode == b"constant": + pad_val = self._get_pad_const_val(attrs_dict) + if pad_val is not None and pad_val.dtype in (np.float32, np.float16): + return float(pad_val.item()) == 0 + + return False + + def quantize(self): + assert self.node.op_type == "Pad" + + for input_name in self.node.input: + if input_name: + self.quantizer.quantize_activation_tensor(input_name) + + if not self.disable_qdq_for_node_output: + if self._should_quantize_output_same_as_input(): + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) + else: + self.quantizer.quantize_activation_tensor(self.node.output[0]) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index b71f332252850..5552a4451c542 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -20,6 +20,7 @@ from .calibrate import TensorData from .quant_utils import ( DEQUANT_OP_NAME, + ONNX_TYPE_TO_NP_TYPE, QUANT_OP_NAME, QuantizedValue, QuantizedValueType, @@ -30,12 +31,14 @@ add_quant_input_suffix, add_quant_output_suffix, add_quant_suffix, + compute_data_quant_params, compute_scale_zp, compute_scale_zp_float8, find_by_name, get_qmin_qmax_for_qType, ms_domain, normalize_axis, + quantize_onnx_initializer, tensor_proto_to_array, ) from .registry import CreateQDQQuantizer @@ -86,6 +89,18 @@ class QDQTensorQuantParams: converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes. converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type. + def get_for_consumer(self, consumer_node_name) -> QuantizationParams: + if self.converted is None: # Quantized value is not converted, return original + return self.original + + if self.converted_recv_nodes is None: # All consumers receive the converted value + return self.converted + + # Check if consumer node name is in the list of nodes that + # receive the converted quantization value. If not, return the original value generated + # by the tensor's producer. + return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original + # Holds scale and zero_point initializer TensorProtos. @dataclass @@ -153,8 +168,8 @@ def __init__( op_types_to_quantize, extra_options, ) - self.tensors_to_quantize = {} - self.bias_to_quantize = {} + self.tensors_to_quantize: dict[str, QDQTensorQuantInfo] = {} + self.bias_to_quantize: dict[str, QDQBiasQuantInfo] = {} self.nodes_to_remove = [] @@ -180,7 +195,11 @@ def __init__( # The default behavior is that multiple nodes can share a QDQ pair as their inputs. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node. self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False) - self.tensor_to_its_receiving_nodes = {} + self.tensor_to_its_receiving_nodes: dict[str, list[onnx.NodeProto]] = {} + + # Maps a tensor to the DequantizeLinear node (in the original input model) that outputs the tensor. + # Populated for input models with some pre-quantized weights (typically via a different tool). + self.tensor_to_producing_dq: dict[str, onnx.NodeProto] = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. 
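The new get_for_consumer method selects between the original and "converted" quantization parameters on a per-consumer basis. A small sketch of the same lookup, using plain dicts as the parameter payloads (the real code uses QuantizationParams):

from __future__ import annotations

from dataclasses import dataclass

@dataclass
class TensorQuantParams:
    original: dict
    converted: dict | None = None
    converted_recv_nodes: set[str] | None = None

    def get_for_consumer(self, consumer: str) -> dict:
        if self.converted is None:
            return self.original
        if self.converted_recv_nodes is None:
            return self.converted  # every consumer receives the converted type
        return self.converted if consumer in self.converted_recv_nodes else self.original

params = TensorQuantParams(
    original={"quant_type": "uint8"},
    converted={"quant_type": "uint16"},
    converted_recv_nodes={"Conv_1"},
)
assert params.get_for_consumer("Conv_1")["quant_type"] == "uint16"
assert params.get_for_consumer("Relu_0")["quant_type"] == "uint8"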
self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {}) @@ -191,6 +210,9 @@ def __init__( # Used in the QDQRemovableActivation class. self.qdq_keep_removable_activations = extra_options.get("QDQKeepRemovableActivations", False) + # Let user disable adjustment of weight scales for bias inputs that are quantized to int32. + self.qdq_disable_weight_adjust_for_int32_bias = extra_options.get("QDQDisableWeightAdjustForInt32Bias", False) + # The ONNX spec did not support 16-bit Q/DQ ops before opset 21. # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types # are 16-bit or 4-bit integers. @@ -213,6 +235,7 @@ def __init__( self.qdq_op_domain = ms_domain self.quantization_params = self.calc_graph_quant_params() + self.initializer_quant_params: dict[str, QuantizationParams] = {} # Map of all original value names to quantized value names self.quantized_value_map = {} @@ -328,6 +351,18 @@ def quantize_weight_tensor_per_channel(self, tensor_name, axis): else: logging.warning(f"only support per-channel quantization on weight. Tensor: {tensor_name} is not quantized.") + def _dup_initializer(self, initializer: onnx.TensorProto) -> onnx.TensorProto: + """ + Duplicates an existing initializer and adds it to the model. Returns the new initializer. + """ + name_suffix: int = self.model.get_largest_initializer_name_suffix(initializer.name) + 1 + new_initializer_name = f"{initializer.name}{name_suffix}" + new_initializer = onnx.TensorProto() + new_initializer.CopyFrom(initializer) + new_initializer.name = new_initializer_name + self.model.add_initializer(new_initializer) + return new_initializer + def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0): """ Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that @@ -353,15 +388,160 @@ def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, be self.quantize_weight_tensor(bias_name) return - weight = find_by_name(bias_name, self.model.initializer()) - if weight is not None: - if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): - if bias_name not in self.bias_to_quantize: - self.bias_to_quantize[bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) - else: - logging.warning(f"Bias {bias_name} has already been marked for quantization") - else: - logging.warning(f"Expected {bias_name} to be a weight") + bias_initializer = find_by_name(bias_name, self.model.initializer()) + if bias_initializer is None: + logging.warning(f"Expected bias '{bias_name}' to be an initializer") + return + + if bias_initializer.data_type not in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): + logging.info(f"Expected bias '{bias_name}' to be an floating-point initializer") + return + + actual_bias_name = bias_name + if bias_name in self.bias_to_quantize: + # This bias input is consumed by two different nodes. We need to duplicate the bias so that + # each node has its own bias input. This is necessary because the bias's scale is computed + # from the node's other input scales. + new_bias_initializer = self._dup_initializer(bias_initializer) + actual_bias_name = new_bias_initializer.name + + # Replace this node's bias input + self.model.replace_input_of_nodes(bias_name, actual_bias_name, {node_name}) + logging.info(f"Created a copy of bias input '{bias_name}' called '{actual_bias_name}'") + + # Add this to our list of biases to quantize. 
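When one bias initializer feeds two different nodes, its quantization scale depends on each node's other inputs, so the patch duplicates the initializer (using the suffix helper added earlier) and rewires only the second consumer. A sketch of that duplication step on bare NodeProto/TensorProto objects; the tensor and node names are invented:

import numpy as np
import onnx
from onnx import helper, numpy_helper

bias = numpy_helper.from_array(np.zeros(4, dtype=np.float32), name="bias")
conv_a = helper.make_node("Conv", ["x_a", "w_a", "bias"], ["y_a"], name="conv_a")
conv_b = helper.make_node("Conv", ["x_b", "w_b", "bias"], ["y_b"], name="conv_b")

# Duplicate the shared bias so conv_b gets its own copy with a unique suffix.
bias_copy = onnx.TensorProto()
bias_copy.CopyFrom(bias)
bias_copy.name = "bias0"          # real code derives the suffix from the existing names
conv_b.input[2] = bias_copy.name  # rewire only the second consumer

assert conv_a.input[2] == "bias" and conv_b.input[2] == "bias0"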
+ self.bias_to_quantize[actual_bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) + + def _adjust_weight_scale_for_int32_bias( + self, + input_scale: np.ndarray, + weight_scale: np.ndarray, + weight_name: str, + bias_tp: onnx.TensorProto, + is_per_channel: bool, + ) -> tuple[bool, np.ndarray | None]: + """ + Checks if the bias scale (input_scale * weight_scale) that we intend to use is too small. + A bias scale that is too small leads to quantized bias values that fall outside the range of a int32 and have to + be clipped, which decreases accuracy. If this function detects such a scenario, the weight_scale value will be + increased to prevent this from happening. + + Although the adjustment method and amount differs, the idea to adjust the weight's scale came from the following + reference: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/optimize/quantization_utils.cc#L252 + + :param input_scale: The input's scale. + :param weight_scale: The weight scale to potentially adjust. + :param weight_name: The weight initializer's name. Used for logging. + :param bias_tp: The bias ONNX initializer. + :param is_per_channel: True if the bias and weight are quantized per-channel. + :return: A tuple with a bool indicating if the weight's scale was adjusted and the new weight scale. + """ + if not weight_scale.size: + return False, None + + bias_float_data = tensor_proto_to_array(bias_tp) + + int32_info = np.iinfo(np.int32) + multiplicative_epsilon = 1.0001 + qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64) + weight_scale_dtype = weight_scale.dtype + updated_an_elem = False + + if not is_per_channel: + rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64)) + rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64)) + absmax = np.maximum(np.abs(rmin), np.abs(rmax)) + bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange + + input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64) + weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64) + bias_candidate_scale = input_scale_fp64 * weight_scale_fp64 + + if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0): + # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio. + ratio = bias_smallest_valid_scale / bias_candidate_scale + logging.info( + f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to " + f"ensure bias input `{bias_tp.name}` has a valid scale." + ) + new_scale = weight_scale_fp64 * ratio + weight_scale = new_scale.astype(weight_scale_dtype) + updated_an_elem = True + elif weight_scale.shape and len(weight_scale.shape) == 1: + # per-channel case + num_elems = weight_scale.shape[0] + + for i in range(num_elems): + bias_rmax = np.abs(bias_float_data[i]) + bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * bias_rmax) / qrange + + input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64) + weight_scale_fp64 = np.array(weight_scale[i].item(), dtype=np.float64) + bias_candidate_scale = input_scale_fp64 * weight_scale_fp64 + if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0): + # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio. 
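The per-tensor branch above reduces to a small amount of arithmetic: the smallest bias scale that still fits the bias into int32 is roughly 1.0001 * 2 * max|bias| / (2^32 - 2), and if input_scale * weight_scale falls below it, the weight scale is multiplied by the shortfall ratio. A numeric sketch with invented values:

import numpy as np

int32_max, int32_min = np.iinfo(np.int32).max, np.iinfo(np.int32).min
qrange = np.float64(int32_max) - np.float64(int32_min + 1)   # 2**32 - 2

bias_absmax = 3.0            # max |bias| value (made up)
input_scale = 1e-4
weight_scale = 1e-8          # deliberately tiny so the adjustment triggers

smallest_valid = 1.0001 * (2.0 * bias_absmax) / qrange
candidate = input_scale * weight_scale
if 0.0 < candidate < smallest_valid:
    ratio = smallest_valid / candidate
    weight_scale *= ratio    # the quantized bias now fits into int32

assert abs(bias_absmax / (input_scale * weight_scale)) <= int32_max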
+ ratio = bias_smallest_valid_scale / bias_candidate_scale + logging.info( + f"Increased scale[{i}] for weight `{weight_name}` by ratio {ratio} " + f"to ensure bias input `{bias_tp.name}` has a valid scale." + ) + new_scale = weight_scale_fp64 * ratio + weight_scale[i] = new_scale.astype(weight_scale_dtype) + updated_an_elem = True + + return updated_an_elem, weight_scale + + def _adjust_weight_quant_params_for_bias_tensors(self): + """ + Iterates through all bias inputs that should be quantized to int32. If the intended + bias scale (equal to input_scale * weight_scale) is too small, this function will increase + the associated weight's scale to ensure the bias does not overflow the int32 range when quantized. + """ + + if self.qdq_disable_weight_adjust_for_int32_bias: + # User passed an extra_option to disable this adjustment. + return + + for bias_name, bias_info in self.bias_to_quantize.items(): + if ( + bias_info.input_name not in self.quantization_params + or bias_info.input_name not in self.tensors_to_quantize + or bias_info.weight_name not in self.initializer_quant_params + ): + continue + + # Get the associated input's scale. + input_qparams = self.quantization_params[bias_info.input_name].get_for_consumer(bias_info.node_name) + input_info = self.tensors_to_quantize[bias_info.input_name] + input_scale = np.asarray( + input_qparams["scale"], dtype=onnx.helper.tensor_dtype_to_np_dtype(input_info.data_type) + ) + + weight_quant_params = self.initializer_quant_params[bias_info.weight_name] + weight_quant_type = weight_quant_params["quant_type"] + if weight_quant_type not in (onnx.TensorProto.INT8, onnx.TensorProto.INT16): + continue + + weight_zero_point: np.ndarray = weight_quant_params["zero_point"] + if weight_zero_point.any(): + # Skip if zero_point(s) are not all zero (i.e., symmetric quant) + continue + + weight_scale: np.ndarray = weight_quant_params["scale"] + is_per_channel = weight_quant_params.get("axis", None) is not None + + # Get adjusted weight scales. + did_update_weight_scale, new_weight_scale = self._adjust_weight_scale_for_int32_bias( + input_scale, + weight_scale, + bias_info.weight_name, + find_by_name(bias_name, self.model.initializer()), + is_per_channel, + ) + + if did_update_weight_scale: + weight_quant_params["scale"] = new_weight_scale def remove_node(self, node): self.nodes_to_remove.append(node) @@ -379,7 +559,12 @@ def quantize_model(self): if tensor_name not in self.tensor_to_its_receiving_nodes: self.tensor_to_its_receiving_nodes[tensor_name] = [] self.tensor_to_its_receiving_nodes[tensor_name].append(node) + if node.op_type == DEQUANT_OP_NAME: + for tensor_name in node.output: + self.tensor_to_producing_dq[tensor_name] = node + self.initializer_quant_params = self._calc_initializer_quant_params() + self._adjust_weight_quant_params_for_bias_tensors() self._quantize_normal_tensors() self._quantize_sharing_param_tensors() if self.quantize_bias: @@ -475,38 +660,26 @@ def _create_qdq_nodes( ) self.model.add_nodes([qlinear_node, dequant_node]) - def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): + def _add_qdq_nodes_for_initializer(self, weight_proto: onnx.TensorProto): + """ + Adds Q/DQ nodes for an initializer. If `self.add_qdq_pair_to_weight` is true, creates + the sequence (weight_f32 -> Q -> DQ -> ). Otherwise, this function quantizes the initializer + and adds the sequence (weight_quant -> DQ ->). 
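The docstring above describes two layouts: with add_qdq_pair_to_weight the float weight is kept and fed through a Q/DQ pair, otherwise the weight is quantized up front and only a DequantizeLinear is emitted. A sketch of the second (default) pattern built with onnx.helper; the names and values are placeholders:

import numpy as np
from onnx import helper, numpy_helper

# Pretend the weight was already quantized to int8 with scale 0.02 and zero point 0.
q_weight = numpy_helper.from_array(np.array([10, -5, 0], dtype=np.int8), name="w_quantized")
scale = numpy_helper.from_array(np.array(0.02, dtype=np.float32), name="w_scale")
zero_point = numpy_helper.from_array(np.array(0, dtype=np.int8), name="w_zero_point")

# (w_quantized -> DequantizeLinear -> w_dequantized); consumers that used the float
# weight are rewired to read w_dequantized instead.
dq_node = helper.make_node(
    "DequantizeLinear",
    ["w_quantized", "w_scale", "w_zero_point"],
    ["w_dequantized"],
    name="w_DequantizeLinear",
)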
+ """ weight_name = weight_proto.name - if axis is not None: - if self.opset_version < 13: - raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") - - qtype = self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType - if qtype == onnx.onnx_pb.TensorProto.UINT8: - qtype = onnx_proto.TensorProto.INT8 - - q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( - weight_name, - # Quantization type is forced to be TensorProto.INT8. - # when the expected value would be (see below) - # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. - # QLinearConv expects to have a unique value for all channels. - # This code does not enforce that but it is necessarily the case when the - # quantization is symmetric (as for INT8). - qtype, - axis, - keep_float_weight=self.add_qdq_pair_to_weight, - ) - else: - q_weight_name, zp_name, scale_name = self.quantize_initializer( - weight_proto, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, - keep_float_weight=self.add_qdq_pair_to_weight, - ) + if weight_name in self.quantized_value_map: + return + quant_params: QuantizationParams = self.initializer_quant_params[weight_name] + axis: int = quant_params.get("axis") + scale_zp_initializers = self._make_scale_zp_initializers(weight_name, quant_params) + q_weight_name: str | None = None weight_dequant_output = add_dequant_output_suffix(weight_name) self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output) + if self.add_qdq_pair_to_weight: + # Don't actually quantize the weight. Instead, keep floating-point weight and create the node + # sequence (weight_f32 -> Q -> DQ -> weight_dequant) weight_quant_output = add_quant_output_suffix(weight_name) self._create_qdq_nodes( @@ -516,14 +689,26 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): weight_quant_output, weight_dequant_output, add_dequant_suffix(weight_name), - scale_name, - zp_name, + scale_zp_initializers.scale.name, + scale_zp_initializers.zero_point.name, axis, ) else: + # Quantize the weight and create the node sequence: + # (weight_quantized -> DQ -> weight_dequant) + quant_weight = quantize_onnx_initializer( + weight_proto, + quant_params["quant_type"], + quant_params["zero_point"], + quant_params["scale"], + axis, + ) + self.model.add_initializer(quant_weight) + + q_weight_name = quant_weight.name dequant_node = onnx.helper.make_node( DEQUANT_OP_NAME, - [q_weight_name, scale_name, zp_name], + [quant_weight.name, scale_zp_initializers.scale.name, scale_zp_initializers.zero_point.name], [weight_dequant_output], add_dequant_suffix(weight_name), axis=axis, @@ -531,6 +716,17 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): ) self.model.add_node(dequant_node) + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_zp_initializers.scale.name, + scale_zp_initializers.zero_point.name, + QuantizedValueType.Initializer, + axis=axis, + ) + self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_type=None): if ( self.dedicated_qdq_pair @@ -767,8 +963,16 @@ def _quantize_normal_tensors(self): # Quantize the input initializer = find_by_name(tensor_name, self.model.initializer()) if initializer: - self._add_qdq_pair_for_initializer(initializer, 
tensor_info.tensor_type, tensor_info.axis) + self._add_qdq_nodes_for_initializer(initializer) else: + # Check if this tensor is already a dequantized value. If so, skip it. + # This happens if the original input model already has some pre-quantized weights + # generated by a different tool. + # Ex: (quantized_weight -> DequantizeLinear -> this_tensor) + if tensor_name in self.tensor_to_producing_dq: + del self.tensors_to_quantize[tensor_name] + continue + tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name) if not tensor_qparam_initializers: raise ValueError( @@ -820,6 +1024,12 @@ def _quantize_sharing_param_tensors(self): if self.is_input_a_initializer(tensor_name): raise ValueError("Quantization parameter shared mode is not supported for weight yet") + if tensor_name in self.tensor_to_producing_dq: + raise ValueError( + f"Quantization parameter sharing is invalid for tensor {tensor_name} " + "because it has already been quantized" + ) + # Need to check if this tensor's quant_type is converted for some consumers. # If so, create new scale/zp initializers for these consumers. converted_qparam_inits = None @@ -909,45 +1119,6 @@ def _quantize_bias_tensors(self): def is_tensor_quantized(self, tensor_name: str): return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize - def quantize_initializer( - self, - weight: onnx.TensorProto, - qType: onnx.TensorProto.DataType, - reduce_range: bool = False, - keep_float_weight: bool = False, - ) -> tuple[str, str, str]: - """ - :param weight: TensorProto initializer - :param qType: type to quantize to - :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. - If keep_float_weight is False, quantize the weight, or don't quantize the weight. - :return: quantized weight name, zero point name, scale name - """ - # Find if this input is already quantized - if weight.name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight.name].original - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - - q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( - weight, qType, reduce_range, keep_float_weight - ) - - # Log entry for this quantized weight - quantized_value = QuantizedValue( - weight.name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight.name] = QDQTensorQuantizedValue(quantized_value, None, None) - return q_weight_name, zp_name, scale_name - def is_tensor_per_channel( self, tensor_name: str, @@ -997,37 +1168,29 @@ def is_tensor_per_channel( return True, axis - def quantize_weight_per_channel( - self, - weight_name: str, - weight_qType: onnx.TensorProto.DataType, - channel_axis: int, - reduce_range: bool = True, - keep_float_weight: bool = False, - ) -> tuple[str, str, str]: - # Find if this input is already quantized - if weight_name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight_name].original - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) + def _get_tensor_quantization_scale(self, tensor_name: str, consumer_node_name: str) -> np.ndarray | None: + """ + Returns the quantization scale of a tensor that is consumed by the given node. + :parameter tensor_name: The name of the tensor. + :parameter consumer_node_name: The name of the node that consumes the tensor as input. 
Necessary in case + the quantization type of the tensor was converted. + Refer: QDQQuantizer::_add_qdq_ops_for_converted_activation. + :returns: The quantization scale or None. + """ + initializers = self.model.initializer() + scale_initializer: onnx.TensorProto | None = None - q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( - weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight - ) - quantized_value = QuantizedValue( - weight_name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + if tensor_name in self.quantized_value_map: + # Tensor was quantized by this tool, so get scale from initializer created by this tool run. + scale_name = self.quantized_value_map[tensor_name].get_for_consumer(consumer_node_name).scale_name + scale_initializer = find_by_name(scale_name, initializers) + else: + # Tensor was already quantized in original model, so get scale from DQ node that outputs the tensor. + dq_node = self.tensor_to_producing_dq.get(tensor_name, None) + if dq_node: + scale_initializer = find_by_name(dq_node.input[1], initializers) - return q_weight_name, zp_name, scale_name + return tensor_proto_to_array(scale_initializer) if scale_initializer is not None else None def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str: """ @@ -1038,17 +1201,21 @@ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> s if bias_name in self.quantized_value_map: return self.quantized_value_map[bias_name].original.q_name - # get scale for weight - weight_scale_name = self.quantized_value_map[bias_info.weight_name].original.scale_name - weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) - weight_scale = tensor_proto_to_array(weight_initializer) + # get scale for weight. + weight_scale = self._get_tensor_quantization_scale(bias_info.weight_name, bias_info.node_name) + if weight_scale is None: + raise ValueError( + f"Unable to get valid quantization scale for weight input '{bias_info.weight_name}' " + f"when quantizing bias '{bias_name}' to int32." + ) - # get scale for input - input_scale_name = ( - self.quantized_value_map[bias_info.input_name].get_for_consumer(bias_info.node_name).scale_name - ) - inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) - input_scale = tensor_proto_to_array(inputscale_initializer) + # get scale for input. + input_scale = self._get_tensor_quantization_scale(bias_info.input_name, bias_info.node_name) + if input_scale is None: + raise ValueError( + f"Unable to get valid quantization scale for input '{bias_info.input_name}' " + f"when quantizing bias '{bias_name}' to int32." + ) ( quantized_bias_name, @@ -1074,7 +1241,7 @@ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> s return quantized_bias_name def _make_scale_zp_initializers( - self, param_name: str, params: QuantizationParams, init_name_suffix: str = "" + self, param_name: str, quant_params: QuantizationParams, init_name_suffix: str = "" ) -> QDQScaleZpInitializers: """ Creates and returns scale and zero-point initializers for the given quantization params. 
The initializers are @@ -1082,31 +1249,31 @@ def _make_scale_zp_initializers( - {param_name}_zero_point{init_name_suffix} - {param_name}_scale{init_name_suffix} """ - zero_point_values = np.array([params["zero_point"]]) - if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): - raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") - scale_values = np.array([params["scale"]]) - assert scale_values.dtype != np.float64 - zero_point_type = params.data.get("quant_type", self.activation_qType) - - zero_point_shape = [] + zero_point = quant_params["zero_point"] + scale = quant_params["scale"] + zero_point_type = quant_params["quant_type"] + axis: int | None = quant_params.get("axis") + assert (axis is not None and len(scale.shape) == 1) or ( + axis is None and len(scale.shape) == 0 + ), "Wrong scale/zp shapes" + assert len(scale.shape) == len(zero_point.shape), "Scale and zero-point must have the same rank" + zero_point_name = param_name + "_zero_point" + init_name_suffix - scale_shape = [] scale_name = param_name + "_scale" + init_name_suffix # Add initializers to model init_zp = onnx.helper.make_tensor( - zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + zero_point_name, zero_point_type, zero_point.shape, zero_point.ravel().tolist() ) self.model.add_initializer(init_zp) - if scale_values.dtype == np.float32: + if scale.dtype == np.float32: scale_type = onnx_proto.TensorProto.FLOAT - elif scale_values.dtype == np.float16: + elif scale.dtype == np.float16: scale_type = onnx_proto.TensorProto.FLOAT16 else: - raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") - init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + raise ValueError(f"Unexpected dtype={scale.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale.shape, scale.ravel().tolist()) self.model.add_initializer(init_scale) return QDQScaleZpInitializers(init_scale, init_zp) @@ -1155,7 +1322,7 @@ def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) - return QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + return QuantizationParams(zero_point=zero.squeeze(), scale=scale.squeeze(), quant_type=quant_type) def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: """ @@ -1185,3 +1352,127 @@ def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes) return quantization_params + + def _calc_initializer_quant_params(self) -> dict[str, QuantizationParams]: + """ + Returns quantization parameters (scale/zero_point/quant_type) for all initializers. + """ + + quantization_params: dict[str, QuantizationParams] = {} + for tensor_name, tensor_info in self.tensors_to_quantize.items(): + initializer = find_by_name(tensor_name, self.model.initializer()) + if not initializer: + continue + + initializer_data = tensor_proto_to_array(initializer) + initializer_rank = len(initializer_data.shape) + + # initializers for elementwise ops use the quant_type for activations. 
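+ # For example, a constant initializer feeding an elementwise Add is not tagged as a WEIGHT, so it falls
+ # back to self.activation_qType; only QDQQuantTensorType.WEIGHT initializers (e.g., Conv/MatMul weights)
+ # use self.weight_qType.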
+ is_weight = tensor_info.tensor_type is QDQQuantTensorType.WEIGHT + quant_type = self.weight_qType if is_weight else self.activation_qType + + # Try to get scale/zp directly from user's overrides and avoid computation. + if self.tensor_quant_overrides.overrides_scale_zp(tensor_name): + overrides = self.tensor_quant_overrides[tensor_name] + if "quant_type" in overrides[0]: + quant_type = overrides[0]["quant_type"].tensor_type + + zp_dtype = ONNX_TYPE_TO_NP_TYPE[quant_type] + is_per_channel = "axis" in overrides[0] + if not is_per_channel: + quantization_params[tensor_name] = QuantizationParams( + zero_point=np.array(overrides[0]["zero_point"], dtype=zp_dtype), + scale=np.array(overrides[0]["scale"], initializer_data.dtype), + quant_type=quant_type, + ) + else: + zero_points_list = [] + scales_list = [] + for chan_overrides in overrides: + zero_points_list.append(np.array(chan_overrides["zero_point"], zp_dtype)) + scales_list.append(np.array(chan_overrides["scale"], dtype=initializer_data.dtype)) + + channel_axis = overrides[0]["axis"] + is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank) + if not is_axis_valid: + raise ValueError( + f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is " + f"out-of-bounds for rank {initializer_rank}" + ) + + quantization_params[tensor_name] = QuantizationParams( + zero_point=np.array(zero_points_list), + scale=np.array(scales_list), + quant_type=quant_type, + axis=norm_channel_axis, + ) + + continue + + # Compute scale/zp normally. User's overrides may still override parameters + # used to compute the scale/zp (e.g., rmin, rmax, symmetric, etc.) + overrides = self.tensor_quant_overrides.get(tensor_name, [{}]) + if "quant_type" in overrides[0]: + quant_type = overrides[0]["quant_type"].tensor_type + + channel_axis = overrides[0].get("axis", tensor_info.axis) + is_per_channel = channel_axis is not None + + # Note: always quantize per-channel initializers as symmetric because QLinear* ops require the + # same zero-point in every channel, which is necessarily the case for symmetric quantization. 
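+ # For example, symmetric INT8 maps [-absmax, absmax] onto [-127, 127] with a zero point of 0 in every
+ # channel, so the per-channel zero points are necessarily identical, as QLinear* ops require.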
+ is_symmetric_default = is_per_channel or ( + self.is_weight_symmetric(quant_type) if is_weight else self.is_activation_symmetric + ) + is_symmetric = overrides[0].get("symmetric", is_symmetric_default) + reduce_range = overrides[0].get("reduce_range", self.reduce_range) + zero_point: np.ndarray | None = None + scale: np.ndarray | None = None + + if not is_per_channel: + zero_point, scale = compute_data_quant_params( + initializer_data.flatten(), + quant_type, + is_symmetric, + reduce_range=reduce_range, + min_real_range=self.min_real_range, + rmin_override=overrides[0].get("rmin"), + rmax_override=overrides[0].get("rmax"), + ) + else: + is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank) + if not is_axis_valid: + raise ValueError( + f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is " + f"out-of-bounds for rank {initializer_rank}" + ) + + channel_axis = norm_channel_axis + channel_count = initializer_data.shape[channel_axis] + zero_points_list = [] + scales_list = [] + for i in range(channel_count): + per_channel_data = initializer_data.take(i, channel_axis) + channel_overrides = overrides[i] if overrides and i < len(overrides) else {} + channel_zero_point, channel_scale = compute_data_quant_params( + per_channel_data.ravel(), + quant_type, + is_symmetric, + reduce_range=reduce_range, + min_real_range=self.min_real_range, + rmin_override=channel_overrides.get("rmin"), + rmax_override=channel_overrides.get("rmax"), + ) + zero_points_list.append(channel_zero_point) + scales_list.append(channel_scale) + + zero_point = np.asarray(zero_points_list) + scale = np.asarray(scales_list) + + quantization_params[tensor_name] = QuantizationParams( + zero_point=zero_point, + scale=scale, + quant_type=quant_type, + axis=channel_axis, + ) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 9228ad33130f2..2bf675745d093 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -33,6 +33,12 @@ int4 = None uint4 = None +try: + from onnx.reference.op_run import to_array_extended +except ImportError: + # old version of onnx. + to_array_extended = None + __producer__ = "onnx.quantize" __version__ = "0.1.0" @@ -43,6 +49,7 @@ DEQUANT_OP_NAME = "DequantizeLinear" DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output" TENSOR_NAME_QUANT_SUFFIX = "_quantized" +MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB FLOAT8_DISTRIBUTIONS = {} @@ -156,7 +163,9 @@ def from_string(format): } ONNX_INT_TYPE_SYMMETRIC_RANGE = { + onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(254, dtype=numpy.uint8)), onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)), + onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65534, dtype=numpy.uint16)), onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)), } @@ -229,7 +238,7 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): # which matches the python reference ONNX implementation of QuantizeLinear. # This data can be packed into 4-bit elements by using pack_bytes_to_4bit(). 
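+ # Note: with symmetric=False, the default clipping bounds below span the full range of qType
+ # (e.g., [-128, 127] for INT8) rather than the narrowed symmetric range (e.g., [-127, 127]).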
dtype = ONNX_TYPE_TO_NP_TYPE[qType] - (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) + qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False) cliplow = max(qmin, low) if low is not None else qmin cliphigh = min(qmax, high) if high is not None else qmax @@ -269,7 +278,7 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non # Ensure a minimum float-point range if specified. if min_real_range is not None: - rmax = max(rmax, rmin + min_real_range) + rmax = max(rmax, rmin + numpy.asarray(min_real_range, dtype=rmin.dtype)) if symmetric: absmax = numpy.maximum(numpy.abs(rmin), numpy.abs(rmax)) @@ -338,13 +347,75 @@ def compute_scale_zp_float8(element_type, std): return [zero, scale] +def compute_data_quant_params( + data: numpy.ndarray, + quant_type: onnx.TensorProto.DataType, + symmetric: bool, + reduce_range: bool = False, + min_real_range: float | None = None, + rmin_override: float | None = None, + rmax_override: float | None = None, +) -> tuple[numpy.ndarray, numpy.ndarray]: + """ + Returns the zero_point and scale for the given data. + + :param data: The data for which to compute quantization parameters. + :param quant_type: The quantization data type. + :param symmetric: whether symmetric quantization is used or not. + :parameter reduce_range: True if the quantization range should be reduced. Defaults to False. + :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None. + :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data). + :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data). + :return: zero point and scale + """ + if not isinstance(data, numpy.ndarray): + raise TypeError(f"Weight must be given as an array not {type(data)}.") + if rmin_override is not None: + rmin = rmin_override + else: + rmin = data.min() if len(data) else 0.0 + + if rmax_override is not None: + rmax = rmax_override + else: + rmax = data.max() if len(data) else 0.0 + + rmin = numpy.array(rmin, dtype=data.dtype) + rmax = numpy.array(rmax, dtype=data.dtype) + scale = numpy.array(1.0, dtype=data.dtype) + + if quant_type == TensorProto.FLOAT8E4M3FN: + if reduce_range: + raise RuntimeError("Unsupported option reduce_range=True for float 8.") + std = numpy.std(data) + zero_point, scale = compute_scale_zp_float8(quant_type, std) + return _check_type(zero_point, scale, zero_point_index=0) + + if quant_type in ( + TensorProto.INT8, + TensorProto.UINT8, + TensorProto.INT16, + TensorProto.UINT16, + TensorProto.INT4, + TensorProto.UINT4, + ): + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range, symmetric=symmetric) + if len(data): + zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) + else: + zero_point = numpy.array(0, dtype=qmin.dtype) + return _check_type(zero_point, scale, zero_point_index=0) + + raise ValueError(f"Unexpected value for quant_type={quant_type}.") + + def quantize_data( data, qType, symmetric, reduce_range=False, min_real_range=None, rmin_override=None, rmax_override=None -): +) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: """ :param data: data to quantize - :param qType: data type to quantize to. Supported types UINT8 and INT8 - :param symmetric: whether symmetric quantization is used or not. This is applied to INT8. + :param qType: data type to quantize to. + :param symmetric: whether symmetric quantization is used or not. 
:parameter reduce_range: True if the quantization range should be reduced. Defaults to False. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None. :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data). @@ -366,28 +437,16 @@ def quantize_data( - *S*: scale - *z*: zero point """ - if not isinstance(data, numpy.ndarray): - raise TypeError(f"Weight must be given as an array not {type(data)}.") - if rmin_override is not None: - rmin = rmin_override - else: - rmin = data.min() if len(data) else 0.0 - - if rmax_override is not None: - rmax = rmax_override - else: - rmax = data.max() if len(data) else 0.0 - - rmin = numpy.array(rmin, dtype=data.dtype) - rmax = numpy.array(rmax, dtype=data.dtype) - zero_point = 0 - scale = numpy.array(1.0, dtype=data.dtype) - + zero_point, scale = compute_data_quant_params( + data, + qType, + symmetric, + reduce_range, + min_real_range, + rmin_override, + rmax_override, + ) if qType == TensorProto.FLOAT8E4M3FN: - if reduce_range: - raise RuntimeError("Unsupported option reduce_range=True for float 8.") - std = numpy.std(data) - zero_point, scale = compute_scale_zp_float8(qType, std) quantized_data = quantize_nparray(qType, data, scale, zero_point) if any((quantized_data.astype(numpy.uint8).ravel() & 127) == 127): np_data = numpy.asarray(data) @@ -395,7 +454,7 @@ def quantize_data( f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], " f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]." ) - return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) + return zero_point, scale, quantized_data if qType in ( TensorProto.INT8, @@ -405,15 +464,91 @@ def quantize_data( TensorProto.INT4, TensorProto.UINT4, ): - if len(data): - qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) - zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) quantized_data = quantize_nparray(qType, data, scale, zero_point) - return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) + return zero_point, scale, quantized_data raise ValueError(f"Unexpected value for qType={qType}.") +def quantize_onnx_initializer( + weight: onnx.TensorProto, + quant_type: onnx.TensorProto.DataType, + zero_point: numpy.ndarray, + scale: numpy.ndarray, + axis: int | None = None, + quant_weight_name: str | None = None, +) -> onnx.TensorProto: + """ + Returns a quantized version of the given ONNX initializer. + + :param weight: The ONNX initializer to quantize. + :param quant_type: The final quantized data type. + :param zero_point: The zero-point value to use for quantization. + :param scale: The scale value to use for quantization. + :param axis: The quantization axis if quantizing per-channel. Defaults to None. + :param quant_weight_name: The name of the quantized initializer. + If not specified, the quantized name is generated. + :return: The quantized ONNX initializer. 
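+
+    Illustrative per-tensor usage (a sketch; `weight_proto` and `onnx_model` are assumed to already exist):
+        scale = numpy.array(0.02, dtype=numpy.float32)
+        zero_point = numpy.array(0, dtype=numpy.int8)
+        q_init = quantize_onnx_initializer(weight_proto, onnx.TensorProto.INT8, zero_point, scale)
+        onnx_model.graph.initializer.append(q_init)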
+ """ + weight_data = tensor_proto_to_array(weight) + q_weight_data: numpy.ndarray | None = None + + if axis is None: # Per-tensor quantization + q_weight_data = quantize_nparray(quant_type, weight_data.ravel(), scale, zero_point) + else: # Per-channel quantization + channel_count = weight_data.shape[axis] + channel_dims = list(weight_data.shape) # deep copy + channel_dims[axis] = 1 # only one per channel for reshape + quantized_channel_data_list = [] + + for i in range(channel_count): + channel_data = weight_data.take(i, axis) + channel_scale = scale[i] + channel_zero_point = zero_point[i] + quantized_channel_data = quantize_nparray( + quant_type, channel_data.ravel(), channel_scale, channel_zero_point + ) + quantized_channel_data_list.append(numpy.asarray(quantized_channel_data).reshape(channel_dims)) + + q_weight_data = numpy.concatenate(quantized_channel_data_list, axis) + + q_weight_name = quant_weight_name if quant_weight_name else f"{weight.name}{TENSOR_NAME_QUANT_SUFFIX}" + + if quant_type == onnx.TensorProto.FLOAT8E4M3FN: + q_weight_initializer = onnx.TensorProto() + q_weight_initializer.data_type = quant_type + q_weight_initializer.dims.extend(weight.dims) + q_weight_initializer.name = q_weight_name + # Do not remove .flatten().copy() numpy is not clear about data persistence. + q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes() + if to_array_extended is not None: + # This test should not be needed but it helped catch some issues + # with data persistence and tobytes. + check = to_array_extended(q_weight_initializer) + if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes(): + raise RuntimeError( + f"The initializer of shape {weight_data.shape} could not be created, expecting " + f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}" + f"\nraw={str(q_weight_initializer)[:200]}." + ) + elif quant_type in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + if q_weight_data.dtype not in (numpy.int8, numpy.uint8): + raise RuntimeError(f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values.") + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. + packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 + q_weight_initializer = onnx.helper.make_tensor(q_weight_name, quant_type, weight.dims, packed_data, raw=True) + else: + quant_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(quant_type) + q_weight_data = numpy.asarray(q_weight_data, dtype=quant_np_dtype).reshape(weight.dims) + q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name) + + return q_weight_initializer + + def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802 """ Return qmin and qmax, the minimum and maximum value representable by the given qType diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 745344dc01fcb..4ffd8b9872982 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -3,10 +3,13 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +from __future__ import annotations + +import copy import logging import tempfile from pathlib import Path -from typing import Union +from typing import Any, Callable import onnx @@ -14,6 +17,7 @@ from .onnx_quantizer import ONNXQuantizer from .qdq_quantizer import QDQQuantizer from .quant_utils import ( + MODEL_SIZE_THRESHOLD, QuantFormat, QuantizationMode, QuantType, @@ -22,6 +26,7 @@ save_and_reload_model_with_shape_infer, ) from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry +from .tensor_quant_overrides import TensorQuantOverridesHelper class QuantConfig: @@ -192,6 +197,9 @@ def __init__( removed if activations are asymmetrically quantized. Keeping these activations is necessary if optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear operators from the model. + QDQDisableWeightAdjustForInt32Bias = True/False: + Default is False. If true, QDQ quantizer will not adjust the weight's scale when the bias + has a scale (input_scale * weight_scale) that is too small. execution_provider : A enum indicates the Execution Provider such as: CPU, TRT, NNAPI, SNE, etc. Raises: ValueError: Raise ValueError if execution provider is unknown @@ -213,6 +221,167 @@ def __init__( self.extra_options = extra_options or {} +def get_qdq_config( + model_input: str | Path | onnx.ModelProto, + calibration_data_reader: CalibrationDataReader, + calibrate_method=CalibrationMethod.MinMax, + calibrate_args: dict[str, Any] | None = None, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + activation_symmetric: bool = False, + weight_symmetric: bool | None = None, + per_channel: bool = False, + reduce_range: bool = False, + keep_removable_activations: bool = False, + min_real_range: float | None = None, + tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None, + nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None, + extra_options: dict | None = None, +) -> StaticQuantConfig: + """ + Returns a configuration suitable that quantizes the entire model to integer precision. + + Params: + model_input: Path to the input model file or ModelProto. + calibration_data_reader: Calibration data reader. + calibrate_methode: The calibration method. Defaults to MinMax. + activation_type: The default activation quantization type. Defaults to QUInt8. + weight_type: The default weight quantization type. Defaults to QInt8. + activation_symmetric: True if activations should be quantized symmetrically (i.e, rmax == -rmin) by default. + Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16, + the zero-point values are 127 and 32,767, respectively. + weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default. + Defaults to None. If set to None, weight_symmetric is assumed true if a weight's quant type is a signed int. + per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel. + Defaults to false. Alternatively, use the tensor-level `tensor_quant_overrides` to select individual operators + and their quantization axes. + reduce_range: quantize weights with 1 less bit of precision (e.g., 7 bits for QInt8). Defaults to false. + May improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode. + keep_removable_activations: Defaults to false. 
If true, "removable" activations (e.g., Clip or Relu) will not + be removed, and will be explicitly represented in the QDQ model. If false, these activations + are automatically removed if activations are asymmetrically quantized. Keeping these activations + is necessary if optimizations or EP transformations will later remove + QuantizeLinear/DequantizeLinear operators from the model. + min_real_range: Default is None. If set to a floating-point value, the calculation of the quantization parameters + (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin) + is less than the specified minimum range, rmax will be set to rmin + min_real_range. + tensor_quant_overrides: tensor-level quantization overrides. Defaults to None. + The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list + contains a single dictionary. For per-channel quantization, the list contains either a dictionary for + each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis' + key must be present in the first dictionary for per-channel quantization. + + Each dictionary contains optional overrides with the following keys and values. + 'quant_type' = QuantType : The tensor's quantization data type. + 'axis' = Int : The per-channel axis. Must be present for per-channel weights. + 'scale' = Float : The scale value to use. Must also specify `zero_point` if set. + 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set. + 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also + set `scale` or `zero_point`. + 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also + set `scale` or `zero_point`. Only valid for initializers. + 'rmax' = Float : Override the maximum real tensor value in calibration data. + Invalid if also set `scale` or `zero_point`. + 'rmin' = Float : Override the minimum real tensor value in calibration data. + Invalid if also set `scale` or `zero_point`. + 'convert' = Dict : A nested dictionary with the same keys for an activation + tensor that should be converted to another quantization type. + 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation, + other nodes get the original type. If not specified, + assume all consumer nodes get the converted type. + nodes_to_exclude: List of nodes names to exclude from quantization. Alternatively, can provide a function that + accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the give onnx.NodeProto + should be excluded from quantization. + extra_options: Additional options specified as string key/value pairs. Refer to the documentation for + `quantize_static` for valid keys and values. + + Returns: + A StaticQuantConfig object + """ + q16_types = {QuantType.QInt16, QuantType.QUInt16} + q4_types = {QuantType.QInt4, QuantType.QUInt4} + op_types_to_exclude = {"Cast", "DequantizeLinear", "QuantizeLinear"} + + model = ( + model_input + if isinstance(model_input, onnx.ModelProto) + else onnx.load_model(model_input, load_external_data=False) + ) + + op_types = set() + model_has_external_data = False + overrides_helper = TensorQuantOverridesHelper( + copy.deepcopy(tensor_quant_overrides) if tensor_quant_overrides else {} + ) + + # check if the model has external data. 
+ for initializer in model.graph.initializer: + if onnx.external_data_helper.uses_external_data(initializer): + model_has_external_data = True + + final_nodes_to_exclude = [] + if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list): + final_nodes_to_exclude.extend(nodes_to_exclude) + + # Iterate through nodes to get all operator types in the model and + # call user's function to filter out nodes from quantization. + for node in model.graph.node: + op_types.add(node.op_type) + if nodes_to_exclude is not None and callable(nodes_to_exclude): + if nodes_to_exclude(model, node): + final_nodes_to_exclude.append(node.name) + + final_extra_options = { + "MinimumRealRange": min_real_range, + "QDQKeepRemovableActivations": keep_removable_activations, + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, + "ForceQuantizeNoInputCheck": True, + "TensorQuantOverrides": overrides_helper.get_dict(), + } + + # Pass along known calibration options + if calibrate_args: + calib_extra_options_keys = [ + ("symmetric", "CalibTensorRangeSymmetric"), + ("moving_average", "CalibMovingAverage"), + ("averaging_constant", "CalibMovingAverageConstant"), + ("max_intermediate_outputs", "CalibMaxIntermediateOutputs"), + ("percentile", "CalibPercentile"), + ] + calib_extra_options = { + key: calibrate_args.get(name) for (name, key) in calib_extra_options_keys if name in calibrate_args + } + final_extra_options.update(calib_extra_options) + + # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain + # on Q/DQ operators if using 16-bit or 4-bit quantization. + onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") + if onnx_opset.version < 21: + opset21_types = q16_types.union(q4_types) + overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types()) + if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types: + final_extra_options["UseQDQContribOps"] = True + + # Allow user's extra_options to override our final_extra_options. + if extra_options: + final_extra_options.update(extra_options) + + return StaticQuantConfig( + calibration_data_reader, + calibrate_method=calibrate_method, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + op_types_to_quantize=list(op_types.difference(op_types_to_exclude)), + nodes_to_exclude=final_nodes_to_exclude, + per_channel=per_channel, + reduce_range=reduce_range, + use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), + extra_options=final_extra_options, + ) + + class DynamicQuantConfig(QuantConfig): def __init__( self, @@ -290,8 +459,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua def quantize_static( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, calibration_data_reader: CalibrationDataReader, quant_format=QuantFormat.QDQ, op_types_to_quantize=None, @@ -438,6 +607,9 @@ def quantize_static( removed if activations are asymmetrically quantized. Keeping these activations is necessary if optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear operators from the model. + QDQDisableWeightAdjustForInt32Bias = True/False: + Default is False. 
If true, QDQ quantizer will not adjust the weight's scale when the bias + has a scale (input_scale * weight_scale) that is too small. """ if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN: if calibrate_method != CalibrationMethod.Distribution: @@ -473,6 +645,7 @@ def quantize_static( ("CalibMovingAverage", "moving_average"), ("CalibMovingAverageConstant", "averaging_constant"), ("CalibMaxIntermediateOutputs", "max_intermediate_outputs"), + ("CalibPercentile", "percentile"), ] calib_extra_options = { key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options @@ -590,8 +763,8 @@ def inc_dataloader(): def quantize_dynamic( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, op_types_to_quantize=None, per_channel=False, reduce_range=False, @@ -690,8 +863,8 @@ def quantize_dynamic( def quantize( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, quant_config: QuantConfig, ): """Quantize a model with QuantConfig. diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 160b056e1de17..fbeae39c39d21 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -14,7 +14,7 @@ from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul from .operators.maxpool import QDQMaxPool, QMaxPool from .operators.norm import QDQNormalization -from .operators.pad import QPad +from .operators.pad import QDQPad, QPad from .operators.pooling import QLinearPool from .operators.qdq_base_operator import QDQOperatorBase from .operators.resize import QDQResize, QResize @@ -76,6 +76,8 @@ "Resize": QDQResize, "MaxPool": QDQMaxPool, "AveragePool": QDQDirect8BitOp, + "Slice": QDQDirect8BitOp, + "Pad": QDQPad, "MatMul": QDQMatMul, "Split": QDQSplit, "Gather": QDQGather, diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py index 219d929d22fce..fbd0cc17f5d81 100644 --- a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -78,6 +78,10 @@ def has_per_channel_overrides(self, tensor_name: str) -> bool: overrides_list = self.overrides.get(tensor_name) return overrides_list and "axis" in overrides_list[0] + def overrides_scale_zp(self, tensor_name: str) -> bool: + overrides_list = self.overrides.get(tensor_name) + return overrides_list and ("scale" in overrides_list[0]) and ("zero_point" in overrides_list[0]) + def get_per_tensor_overrides( self, tensor_name: str, diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index 763d160fa56b5..3ebc33c02592d 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -17,8 +17,8 @@ TRT_DOCKER_FILES = { "8.6.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6", "8.6.cuda_12_3_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6", - "10.4.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10", - "10.4.cuda_12_5_cudnn_9": 
"tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10", + "10.5.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10", + "10.5.cuda_12_5_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10", "BIN": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin", } diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 26f8987c76623..55ce8d752a9d6 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -348,7 +348,7 @@ def run_pytorch( else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length logger.debug(f"Model {model}") logger.debug(f"Number of parameters {model.num_parameters()}") @@ -500,7 +500,7 @@ def run_tensorflow( tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length for batch_size in batch_sizes: if batch_size <= 0: diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index 167fc8697ce06..ccf2497d61342 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -250,6 +250,7 @@ def generate_test_data( average_sequence_length: int, random_sequence_length: bool, mask_type: int, + dictionary_size: int = 10000, ): """Create given number of input data for testing @@ -270,7 +271,6 @@ def generate_test_data( List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary with input name as key and a tensor as value """ - dictionary_size = 10000 all_inputs = fake_test_data( batch_size, sequence_length, diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py index 0c5125e74c8a4..03bcc20d9a5de 100644 --- a/onnxruntime/python/tools/transformers/compare_bert_results.py +++ b/onnxruntime/python/tools/transformers/compare_bert_results.py @@ -85,6 +85,7 @@ def run_test( segment_ids_name, input_mask_name, mask_type, + dictionary_size: int = 1024, ): # Try deduce input names from optimized model. input_ids, segment_ids, input_mask = get_bert_inputs( @@ -105,6 +106,7 @@ def run_test( average_sequence_length, True, # random sequence length mask_type, + dictionary_size=dictionary_size, ) baseline_results, baseline_latency, output_names = run_model( diff --git a/onnxruntime/python/tools/transformers/dev_benchmark.cmd b/onnxruntime/python/tools/transformers/dev_benchmark.cmd index 82137de3c0f3b..4bef58621e8c0 100644 --- a/onnxruntime/python/tools/transformers/dev_benchmark.cmd +++ b/onnxruntime/python/tools/transformers/dev_benchmark.cmd @@ -3,9 +3,7 @@ REM Run benchmark in Windows for developing purpose. For official benchmark, please use run_benchmark.sh. REM Settings are different from run_benchmark.sh: no cli, batch and sequence, input counts, average over 100, no fp16, less models etc. -REM Please install PyTorch (see https://pytorch.org/) before running this benchmark. 
Like the following: -REM GPU: conda install pytorch torchvision cudatoolkit=10.1 -c pytorch -REM CPU: conda install pytorch torchvision cpuonly -c pytorch +REM Please install PyTorch (see https://pytorch.org/) before running this benchmark. REM When use_package=true, you need not copy other files to run benchmarks except this sh file. REM Otherwise, it will use python script (*.py) files in this directory. @@ -21,12 +19,12 @@ set run_torchscript=false REM Devices to test. REM Attention: You cannot run both CPU and GPU at the same time: gpu need onnxruntime-gpu, and CPU need onnxruntime. -set run_gpu_fp32=false -set run_gpu_fp16=false -set run_cpu_fp32=true -set run_cpu_int8=true +set run_gpu_fp32=true +set run_gpu_fp16=true +set run_cpu_fp32=false +set run_cpu_int8=false -set average_over=100 +set average_over=1000 REM Enable optimizer (use script instead of OnnxRuntime for graph optimization) set use_optimizer=true @@ -36,7 +34,7 @@ set sequence_length=8 128 REM Number of inputs (input_ids, token_type_ids, attention_mask) for ONNX model. REM Note that different input count might lead to different performance -set input_counts=1 +set input_counts=3 REM Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased set models_to_test=bert-base-cased @@ -57,7 +55,6 @@ if %run_cpu_int8% == true if %run_gpu_fp32% == true echo cannot test cpu and gpu if %run_cpu_int8% == true if %run_gpu_fp16% == true echo cannot test cpu and gpu at same time & goto :EOF if %run_install% == true ( - pip uninstall --yes ort_nightly pip uninstall --yes onnxruntime pip uninstall --yes onnxruntime-gpu if %run_cpu_fp32% == true ( @@ -70,7 +67,6 @@ if %run_install% == true ( ) ) - pip install --upgrade onnxconverter_common pip install --upgrade transformers ) diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 2398bb9d6031b..74adc951c4aa3 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -132,6 +132,7 @@ def make_value_info_from_tensor(tensor): "Scaler", "TreeEnsembleClassifier", "TreeEnsembleRegressor", + "TreeEnsemble", "ZipMap", "NonMaxSuppression", "TopK", diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index a9ff623fb6967..030708783bb61 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -42,26 +42,26 @@ def get_first_mask(self): assert len(self.mask_indice) > 0 return next(iter(self.mask_indice)) - def process_mask(self, input: str) -> str: + def process_mask(self, mask_2d: str) -> Optional[str]: if self.mask_format == AttentionMaskFormat.NoMask: return None - if input in self.mask_indice: - return self.mask_indice[input] + if mask_2d in self.mask_indice: + return self.mask_indice[mask_2d] # Add cast to convert int64 to int32 - if self.model.find_graph_input(input): - casted, input_name = self.utils.cast_graph_input_to_int32(input) + if self.model.find_graph_input(mask_2d): + casted, input_name = self.utils.cast_graph_input_to_int32(mask_2d) else: - input_name, cast_node = self.utils.cast_input_to_int32(input) + input_name, _cast_node = self.utils.cast_input_to_int32(mask_2d) casted = True if casted: - self.mask_casted[input] = input_name + self.mask_casted[mask_2d] = input_name # Attention supports int32 attention mask (2D) since 1.4.0 if 
self.mask_format == AttentionMaskFormat.AttentionMask: - self.mask_indice[input] = input_name + self.mask_indice[mask_2d] = input_name return input_name # Add a mask processing node to convert attention mask to mask index (1D) @@ -97,7 +97,7 @@ def process_mask(self, input: str) -> str: self.model.add_node(mask_index_node) - self.mask_indice[input] = output_name + self.mask_indice[mask_2d] = output_name return output_name @@ -173,17 +173,20 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] Tuple[int, int]: num_heads and hidden_size """ # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size] - q_shape = self.model.get_initializer(reshape_q.input[1]) - if q_shape is None: + q_shape_value = self.model.get_constant_value(reshape_q.input[1]) + if q_shape_value is None: concat = self.model.get_parent(reshape_q, 1) if concat is not None and concat.op_type == "Concat": return self.get_num_heads_and_hidden_size_from_concat(concat) - logger.debug(f"{reshape_q.input[1]} is not initializer.") + logger.debug("%s is not initializer.", reshape_q.input[1]) return self.num_heads, self.hidden_size # Fall back to user specified value - q_shape_value = NumpyHelper.to_array(q_shape) - if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0): - logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].") + if ( + (not isinstance(q_shape_value, np.ndarray)) + or len(q_shape_value) != 4 + or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0) + ): + logger.debug("q_shape_value=%s. Expected value are like [0, 0, num_heads, head_size].", q_shape_value) return self.num_heads, self.hidden_size # Fall back to user specified value num_heads = q_shape_value[2] @@ -192,13 +195,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] if self.num_heads > 0 and num_heads != self.num_heads: if self.num_heads_warning: - logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.") + logger.warning( + "--num_heads is %d. Detected value is %d. Using detected value.", self.num_heads, num_heads + ) self.num_heads_warning = False # Do not show the warning more than once if self.hidden_size > 0 and hidden_size != self.hidden_size: if self.hidden_size_warning: logger.warning( - f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + "--hidden_size is %d. Detected value is %d. Using detected value.", self.hidden_size, hidden_size ) self.hidden_size_warning = False # Do not show the warning more than once @@ -216,11 +221,11 @@ def get_add_qk_str(self, add_qk: NodeProto): input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1]) if input_0_shape is None or input_1_shape is None: - logger.debug(f"one of the inputs of {add_qk} is None") + logger.debug("one of the inputs of %s is None", add_qk) return None if input_0_shape != input_1_shape: - logger.debug(f"the shape of two inputs of {add_qk} is not same") + logger.debug("the shape of two inputs of %s is not same", add_qk) return None return add_qk.input[1] @@ -305,55 +310,6 @@ def concat_kv(self, past_k: str, past_v: str) -> str: return kv_output_name - def reshape_kv(self, past_k: str, past_v: str) -> (str, str): - """Reshape past_k and past_v from 4D to 3D to use as inputs for multihead attention node. 
- - Args: - past_k (str): name of past K value of shape 4D - past_v (str): name of past V value of shape 4D - - Returns: - k_3d (str): name of past K value of shape 3D - v_3d (str): name of past V value of shape 3D - """ - # Reshape past_k and past_v from (B,N,P,H) to (B,P,N*H) - # B = batch size, N = num heads, P = past seq len, H = head size - - # Create initializer for reshaping past_k and past_v - new_dims_name = "kv_4d_to_3d" - new_dims = self.model.get_initializer(new_dims_name) - if new_dims is None: - new_dims = numpy_helper.from_array( - np.array([0, -1, self.model.hidden_size], dtype="int64"), name=new_dims_name - ) - self.model.add_initializer(new_dims, self.this_graph_name) - - reshape_k_name = self.model.create_node_name("Reshape") - reshape_v_name = self.model.create_node_name("Reshape") - k_3d_name = (past_k + "_3d").replace(".", "_") - v_3d_name = (past_v + "_3d").replace(".", "_") - - k_3d = helper.make_node( - "Reshape", - inputs=[past_k, new_dims_name], - outputs=[k_3d_name], - name=reshape_k_name, - ) - v_3d = helper.make_node( - "Reshape", - inputs=[past_v, new_dims_name], - outputs=[v_3d_name], - name=reshape_v_name, - ) - - # Add reshape nodes to graph - self.nodes_to_add.append(k_3d) - self.nodes_to_add.append(v_3d) - self.node_name_to_graph_name[reshape_k_name] = self.this_graph_name - self.node_name_to_graph_name[reshape_v_name] = self.this_graph_name - - return k_3d_name, v_3d_name - def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str): """Split kv_node containing present KV values into separate present K and present V values. @@ -476,8 +432,7 @@ def create_packed_qkv_matmul_node( q_add: NodeProto, k_add: Union[NodeProto, None], v_add: Union[NodeProto, None], - num_heads: int, - ) -> Union[NodeProto, None]: + ) -> Tuple[NodeProto, NodeProto, NodeProto]: """Create packed QKV MatMul node before MultiHeadAttention node. This is for the scenario where an Attention node should be created but cannot be created because past_key and past_value are separate inputs and not one concatenated input. @@ -489,10 +444,11 @@ def create_packed_qkv_matmul_node( q_add (NodeProto): name of Add from Q path k_add (NodeProto): name of Add from K path v_add (NodeProto): name of Add from V path - num_heads (int): number of heads Returns: - Union[NodeProto, None]: the node created or None if failed. + q_output (NodeProto): Slice node for Q + k_output (NodeProto): Slice node for K + v_output (NodeProto): Slice node for V """ matmul_node_name = self.model.create_node_name("MatMul") @@ -611,6 +567,7 @@ def create_packed_qkv_matmul_node( self.nodes_to_add.extend(qkv_nodes) return q_output, k_output, v_output + # This function is used in child classes for bart or conformer model. 
def create_multihead_attention_node( self, q_matmul: NodeProto, @@ -659,7 +616,7 @@ def create_multihead_attention_node( assert num_heads > 0 if hidden_size > 0 and (hidden_size % num_heads) != 0: - logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads) return None graph_input_names = set([node.name for node in self.model.graph().input]) @@ -669,17 +626,22 @@ def create_multihead_attention_node( mha_inputs = [] if packed_qkv: q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node( - q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, num_heads + q_matmul, + k_matmul, + v_matmul, + q_add, + k_add, + v_add, ) mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]]) - elif type(k_matmul) is NodeProto and type(v_matmul) is NodeProto: + elif isinstance(k_matmul, NodeProto) and isinstance(v_matmul, NodeProto): if self.disable_multi_head_attention_bias: mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]]) else: mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]]) elif ( - type(k_matmul) == str # noqa: E721 - and type(v_matmul) == str # noqa: E721 + isinstance(k_matmul, str) + and isinstance(v_matmul, str) and k_matmul in graph_input_names and v_matmul in graph_input_names ): @@ -724,7 +686,7 @@ def create_multihead_attention_node( def create_attention_node( self, - mask_index: str, + mask_index: Optional[str], q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto, @@ -733,7 +695,7 @@ def create_attention_node( v_add: NodeProto, num_heads: int, hidden_size: int, - input: str, + first_input: str, output: str, add_qk_str: str = "", past_k: str = "", @@ -746,7 +708,7 @@ def create_attention_node( """Create an Attention node. Args: - mask_index (str): mask input + mask_index (str | None): mask input q_matmul (NodeProto): MatMul node in fully connection for Q k_matmul (NodeProto): MatMul node in fully connection for K v_matmul (NodeProto): MatMul node in fully connection for V @@ -755,7 +717,7 @@ def create_attention_node( v_add (NodeProto): Add bias node in fully connection for V num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. - input (str): input name + first_input (str): first input name output (str): output name add_qk_str (str): name of Add node after Q x K' past_k (str): name of input for past K value @@ -771,7 +733,7 @@ def create_attention_node( assert num_heads > 0 if hidden_size > 0 and (hidden_size % num_heads) != 0: - logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads) return None has_bias = True @@ -813,8 +775,10 @@ def create_attention_node( if hidden_size > 0 and hidden_size != qw_in_size: logger.warning( - f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). " - "Please provide a correct input hidden size or pass in 0" + "Input hidden size (%d) is not same as weight matrix dimension of q,k,v (%d). 
" + "Please provide a correct input hidden size or pass in 0", + hidden_size, + qw_in_size, ) is_qkv_diff_dims = False @@ -836,6 +800,8 @@ def create_attention_node( qkv_weight = np.stack((qw, kw, vw), axis=1) qkv_weight_dim = 3 * qw_out_size + qkv_bias_dim = 0 + qkv_bias: Optional[np.ndarray] = None if has_bias: qb = NumpyHelper.to_array(q_bias) kb = NumpyHelper.to_array(k_bias) @@ -861,7 +827,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_weight", data_type=q_weight.data_type, - dims=[qw_in_size, qkv_weight_dim], + dims=[qw_in_size, int(qkv_weight_dim)], vals=qkv_weight, ) @@ -869,7 +835,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_bias", data_type=q_bias.data_type, - dims=[qkv_bias_dim], + dims=[int(qkv_bias_dim)], vals=qkv_bias, ) @@ -897,7 +863,7 @@ def create_attention_node( ) else: attention_inputs = [ - input, + first_input, attention_node_name + "_qkv_weight", attention_node_name + "_qkv_bias" if has_bias else "", ] @@ -911,7 +877,7 @@ def create_attention_node( past_kv = self.concat_kv(past_k, past_v) attention_inputs.append(past_kv) - if add_qk_str is not None: + if add_qk_str: mask_output_name = self.reshape_add_qk(add_qk_str) # Add attention mask to attention node @@ -951,9 +917,10 @@ def create_attention_node( return attention_node - def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + def fuse(self, node, input_name_to_nodes, output_name_to_node): # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern + normalize_node = node start_node = normalize_node if normalize_node.op_type == "LayerNormalization": add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) @@ -982,25 +949,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return other_inputs = [] - for _i, input in enumerate(start_node.input): - if input not in output_name_to_node: + for _i, node_input in enumerate(start_node.input): + if node_input not in output_name_to_node: continue - if input == qkv_nodes[0].output[0]: + if node_input == qkv_nodes[0].output[0]: continue - other_inputs.append(input) + other_inputs.append(node_input) if len(other_inputs) != 1: return root_input = other_inputs[0] - """ - Match flaubert Mask - | - Mul --> LayerNormalization --> Attention --> MatMul --> Add - | | - | | - +--------------------------------------------------------- - """ + + # Match flaubert Mask + # | + # Mul --> LayerNormalization --> Attention --> MatMul --> Add + # | | + # | | + # +--------------------------------------------------------- mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0) if mul_before_layernorm is not None: mul_children = input_name_to_nodes[mul_before_layernorm.output[0]] @@ -1020,19 +986,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if child.op_type == "LayerNormalization": root_input = child.output[0] - """ - When Add before the LayerNormalization produces an output - that is consumed by some other nodes other than the LayerNormalization itself, - fused SkipLayerNormalization will have several outputs. 
- In this case we need to pick the one used in Attention - - For example, this is the case for ViT - - SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization - | | - | | - +---------------------------------------------------------------------+ - """ + # When Add before the LayerNormalization produces an output + # that is consumed by some other nodes other than the LayerNormalization itself, + # fused SkipLayerNormalization will have several outputs. + # In this case we need to pick the one used in Attention + # For example, this is the case for ViT + # SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization + # | | + # | | + # +---------------------------------------------------------------------+ parent_node = output_name_to_node[root_input] if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4: root_input = parent_node.output[0] @@ -1051,12 +1013,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): is_distill = False is_distill_add = False is_no_mask_attention = False + is_sdpa = False qk_paths = { "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]), "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]), "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]), "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]), "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]), + "sdpa": (["Softmax", "Add", "MatMul", "Mul", "Sqrt"], [0, 0, None, 0, 1]), } qk_nodes = None @@ -1066,10 +1030,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): continue if k == "path3": is_distill = True - if k == "path4": + elif k == "path4": is_distill_add = True - if k == "path5": + elif k == "path5": is_no_mask_attention = True + elif k == "sdpa": + is_sdpa = True break if qk_nodes is None: @@ -1079,19 +1045,23 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_qk = None matmul_qk = None where_qk = None + after_q = None if is_distill: (_, where_qk, matmul_qk, _) = qk_nodes elif is_distill_add: (_, add_qk, where_qk, matmul_qk) = qk_nodes elif is_no_mask_attention: (_, _, matmul_qk) = qk_nodes + elif is_sdpa: + (_, add_qk, matmul_qk, after_q, _) = qk_nodes else: (_, add_qk, _, matmul_qk) = qk_nodes - q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]) + after_q = after_q or matmul_qk + q_nodes = self.model.match_parent_path(after_q, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]) if q_nodes is None: q_nodes = self.model.match_parent_path( - matmul_qk, + after_q, ["Div", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None], ) @@ -1102,7 +1072,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_q = q_nodes[-2] matmul_q = q_nodes[-1] - k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]) + after_k = matmul_qk + if is_sdpa: + mul_k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Sqrt"], [1, None]) + if mul_k_nodes is None: + logger.debug("fuse_attention: failed to match mul sqrt q path") + return + (after_k, _) = mul_k_nodes + + k_nodes = self.model.match_parent_path( + after_k, ["Transpose", "Reshape", "Add", "MatMul"], [0 if is_sdpa else 1, 0, 0, None] + ) if k_nodes is None: k_nodes = self.model.match_parent_path( matmul_qk, @@ -1117,7 +1097,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Note that Cast might 
be removed by OnnxRuntime so we match two patterns here. mask_nodes = None - add_qk_str = None + add_qk_str = "" if is_distill: _, mask_nodes, _ = self.model.match_parent_paths( where_qk, @@ -1140,7 +1120,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if add_qk is not None: add_qk_str = self.get_add_qk_str(add_qk) if add_qk_str is None: - logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}") + logger.debug("fuse_attention: failed to verify shape inference of %s", add_qk) return elif is_no_mask_attention: pass @@ -1148,11 +1128,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): _, mask_nodes, _ = self.model.match_parent_paths( add_qk, [ - ( - ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], - [None, 0, 1, 0, 0], - ), + (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]), (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]), + # The following two patterns are for SDPA. + (["Where", "Cast", "Sub", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0]), + (["Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0, 0]), ], output_name_to_node, ) @@ -1160,10 +1140,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: failed to match mask path") return - if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul": + if not is_no_mask_attention and len(mask_nodes) > 1: _, mul_val = self.model.get_constant_input(mask_nodes[0]) - if mul_val != -10000: - self.mask_filter_value = mul_val + # The mask value shall be a float scalar (usually is the lowest float value). + if ( + (mul_val is None) + or not (isinstance(mul_val, np.ndarray) and mul_val.size == 1) + or (float(mul_val) >= 0) + ): + return + if float(mul_val) != -10000: + self.mask_filter_value = float(mul_val) if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input: mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None @@ -1181,19 +1168,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately new_node = self.create_attention_node( - mask_index, - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - q_num_heads, - q_hidden_size, - root_input, - attention_last_node.output[0], - add_qk_str, + mask_index=mask_index, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=q_num_heads, + hidden_size=q_hidden_size, + first_input=root_input, + output=attention_last_node.output[0], + add_qk_str=add_qk_str, ) + if new_node is None: return @@ -1208,7 +1196,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): name="shape_modified_tensor" + unique_index, data_type=TensorProto.INT64, dims=[4], - vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]), + vals=[0, 0, q_num_heads, int(q_hidden_size / q_num_heads)], raw=False, ) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index b027957fcc725..16e2c36bfd092 100644 --- 
a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -239,9 +239,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): v_add=add_v, num_heads=num_heads, hidden_size=hidden_size, - input=root_input, + first_input=root_input, output=attention_last_node.output[0], - add_qk_str=None, + add_qk_str="", scale=None, causal=(add_mask is not None), ) diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index ebecc1db24792..8c334b83abfeb 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -564,15 +564,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # value whereas attention supports concatenated past key and past value. new_node = ( self.create_multihead_attention_node( - matmul_q, - matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, - matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, - add_q, - add_k if decoder_cross_attention or decoder_attention_with_past else None, - add_v if decoder_cross_attention or decoder_attention_with_past else None, - num_heads, - hidden_size, - attention_last_node.output[0], + q_matmul=matmul_q, + k_matmul=matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, + v_matmul=matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, + q_add=add_q, + k_add=add_k if decoder_cross_attention or decoder_attention_with_past else None, + v_add=add_v if decoder_cross_attention or decoder_attention_with_past else None, + num_heads=num_heads, + hidden_size=hidden_size, + output=attention_last_node.output[0], past_k=past_k if decoder_attention_with_past else "", past_v=past_v if decoder_attention_with_past else "", present_k=present_k, @@ -586,19 +586,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Temporarily set multihead attention flag to false use_multi_head_attention_ground_truth = self.use_multi_head_attention self.use_multi_head_attention = False + add_qk_str = mask_index if decoder_attention and mask_index else "" new_node = self.create_attention_node( - None, - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - num_heads, - hidden_size, - root_input, - attention_last_node.output[0], - add_qk_str=mask_index if decoder_attention else None, + mask_index=None, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=num_heads, + hidden_size=hidden_size, + first_input=root_input, + output=attention_last_node.output[0], + add_qk_str=add_qk_str, past_k=past_k, past_v=past_v, present_k=present_k, diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py index 6bc681c57444e..f29d0a0ac9441 100644 --- a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -102,15 +102,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return new_node = self.create_multihead_attention_node( - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - num_heads, - hidden_size, - attention_last_node.output[0], + q_matmul=matmul_q, + k_matmul=matmul_k, + 
v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=num_heads, + hidden_size=hidden_size, + output=attention_last_node.output[0], add_qk=add_qk.input[1], past_k=past_k, past_v=past_v, diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py index dcfe4a28ad9af..4cd878a4656a7 100644 --- a/onnxruntime/python/tools/transformers/huggingface_models.py +++ b/onnxruntime/python/tools/transformers/huggingface_models.py @@ -13,155 +13,62 @@ "AutoModelForCausalLM", ] -# List of pretrained models: https://huggingface.co/transformers/pretrained_models.html # Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type +# Some models like GPT, T5, Bart etc has its own convert_to_onnx.py in models sub-directory, and they are excluded here. MODELS = { # BERT - "bert-base-uncased": ( - ["input_ids", "attention_mask", "token_type_ids"], - 12, - False, - "bert", - ), - "bert-large-uncased": ( - ["input_ids", "attention_mask", "token_type_ids"], - 12, - False, - "bert", - ), - "bert-base-cased": ( - ["input_ids", "attention_mask", "token_type_ids"], - 12, - False, - "bert", - ), - # "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-chinese": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-german-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-large-uncased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-large-cased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask", - # "token_type_ids"], 12, False, "bert"), - # "bert-large-cased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask", - # "token_type_ids"], 12, False, "bert"), - # "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-german-dbmdz-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-german-dbmdz-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # todo: more models to add - # GPT (no past state) - "openai-gpt": (["input_ids"], 11, False, "gpt2"), - # GPT-2 (no past state, use benchmark_gpt2.py for past_key_values) - "gpt2": (["input_ids"], 11, False, "gpt2"), - "gpt2-medium": (["input_ids"], 11, False, "gpt2"), - "gpt2-large": (["input_ids"], 11, True, "gpt2"), - "gpt2-xl": (["input_ids"], 11, True, "gpt2"), - "distilgpt2": (["input_ids"], 11, False, "gpt2"), - # Transformer-XL (Models uses Einsum, which need opset version 12 or later.) - "transfo-xl-wt103": (["input_ids", "mems"], 12, False, "bert"), + "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"), + "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"), + # Transformer-XL (Models uses Einsum, which need opset version 16 or later.) 
+ "transfo-xl-wt103": (["input_ids", "mems"], 16, False, "bert"), # XLNet - "xlnet-base-cased": (["input_ids"], 12, False, "bert"), - "xlnet-large-cased": (["input_ids"], 12, False, "bert"), + "xlnet-base-cased": (["input_ids"], 16, False, "bert"), + "xlnet-large-cased": (["input_ids"], 16, False, "bert"), # XLM - "xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"), - "xlm-mlm-ende-1024": (["input_ids"], 11, False, "bert"), - "xlm-mlm-enfr-1024": (["input_ids"], 11, False, "bert"), + "xlm-mlm-en-2048": (["input_ids"], 16, True, "bert"), + "xlm-mlm-ende-1024": (["input_ids"], 16, False, "bert"), + "xlm-mlm-enfr-1024": (["input_ids"], 16, False, "bert"), # RoBERTa - "roberta-base": (["input_ids", "attention_mask"], 12, False, "bert"), - "roberta-large": (["input_ids", "attention_mask"], 12, False, "bert"), - "roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"), - "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"), - "distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"), + "roberta-base": (["input_ids", "attention_mask"], 16, False, "bert"), + "roberta-large": (["input_ids", "attention_mask"], 16, False, "bert"), + "roberta-large-mnli": (["input_ids", "attention_mask"], 16, False, "bert"), + "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 16, False, "bert"), + "distilroberta-base": (["input_ids", "attention_mask"], 16, False, "bert"), # DistilBERT - "distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"), - "distilbert-base-uncased-distilled-squad": ( - ["input_ids", "attention_mask"], - 11, - False, - "bert", - ), + "distilbert-base-uncased": (["input_ids", "attention_mask"], 16, False, "bert"), + "distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 16, False, "bert"), # CTRL - "ctrl": (["input_ids"], 11, True, "bert"), + "ctrl": (["input_ids"], 16, True, "bert"), # CamemBERT - "camembert-base": (["input_ids"], 11, False, "bert"), + "camembert-base": (["input_ids"], 16, False, "bert"), # ALBERT - "albert-base-v1": (["input_ids"], 12, False, "bert"), - "albert-large-v1": (["input_ids"], 12, False, "bert"), - "albert-xlarge-v1": (["input_ids"], 12, True, "bert"), - # "albert-xxlarge-v1": (["input_ids"], 12, True, "bert"), - "albert-base-v2": (["input_ids"], 12, False, "bert"), - "albert-large-v2": (["input_ids"], 12, False, "bert"), - "albert-xlarge-v2": (["input_ids"], 12, True, "bert"), - # "albert-xxlarge-v2": (["input_ids"], 12, True, "bert"), - # T5 (use benchmark_t5.py instead) - # "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"), - # "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"), - # "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"), - # "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), - # "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), - # "valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"), + "albert-base-v1": (["input_ids"], 16, False, "bert"), + "albert-large-v1": (["input_ids"], 16, False, "bert"), + "albert-xlarge-v1": (["input_ids"], 16, True, "bert"), + # "albert-xxlarge-v1": (["input_ids"], 16, True, "bert"), + "albert-base-v2": (["input_ids"], 16, False, "bert"), + "albert-large-v2": (["input_ids"], 16, False, "bert"), + "albert-xlarge-v2": (["input_ids"], 16, True, "bert"), + # "albert-xxlarge-v2": (["input_ids"], 16, True, "bert"), # XLM-RoBERTa - "xlm-roberta-base": (["input_ids"], 11, False, "bert"), - "xlm-roberta-large": 
(["input_ids"], 11, True, "bert"), + "xlm-roberta-base": (["input_ids"], 16, False, "bert"), + "xlm-roberta-large": (["input_ids"], 16, True, "bert"), # FlauBERT - "flaubert/flaubert_small_cased": (["input_ids"], 11, False, "bert"), - # "flaubert/flaubert_base_uncased": (["input_ids"], 11, False, "bert"), - "flaubert/flaubert_base_cased": (["input_ids"], 11, False, "bert"), - # "flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"), - # Bart - "facebook/bart-large": (["input_ids", "attention_mask"], 11, False, "bart"), - "facebook/bart-base": (["input_ids", "attention_mask"], 11, False, "bart"), - "facebook/bart-large-mnli": (["input_ids", "attention_mask"], 11, False, "bart"), - "facebook/bart-large-cnn": (["input_ids", "attention_mask"], 11, False, "bart"), - # DialoGPT - "microsoft/DialoGPT-small": (["input_ids"], 11, False, "gpt2"), - "microsoft/DialoGPT-medium": (["input_ids"], 11, False, "gpt2"), - # "microsoft/DialoGPT-large": (["input_ids"], 11, True, "gpt2"), - # Reformer - # "google/reformer-enwik8": (["input_ids"], 11, False, "bert"), - # "google/reformer-crime-and-punishment": (["input_ids"], 11, False, "bert"), - # MarianMT - # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"), - # Longformer (use benchmark_longformer.py instead) - # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"), - # "allenai/longformer-large-4096": (["input_ids"], 12, False, "bert"), - # MBart - "facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"), - "facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"), - # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"), - # # Longformer - # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"), - # "allenai/longformer-large-4096": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/small": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/small-base": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/medium": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/medium-base": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/intermediate": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/intermediate-base": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/large": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/large-base": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/xlarge": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/xlarge-base": (["input_ids"], 12, True, "bert"), + "flaubert/flaubert_small_cased": (["input_ids"], 16, False, "bert"), + "flaubert/flaubert_base_cased": (["input_ids"], 16, False, "bert"), + # "flaubert/flaubert_large_cased": (["input_ids"], 16, False, "bert"), # Layoutlm - "microsoft/layoutlm-base-uncased": (["input_ids"], 11, False, "bert"), - "microsoft/layoutlm-large-uncased": (["input_ids"], 11, False, "bert"), + "microsoft/layoutlm-base-uncased": (["input_ids"], 16, False, "bert"), + "microsoft/layoutlm-large-uncased": (["input_ids"], 16, False, "bert"), # Squeezebert - "squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"), - "squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"), - "squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"), - "unc-nlp/lxmert-base-uncased": ( - ["input_ids", "visual_feats", "visual_pos"], - 11, - False, - "bert", - ), - # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"), - # "google/pegasus-large": (["input_ids"], 11, False, "bert"), + 
"squeezebert/squeezebert-uncased": (["input_ids"], 16, False, "bert"), + "squeezebert/squeezebert-mnli": (["input_ids"], 16, False, "bert"), + "squeezebert/squeezebert-mnli-headless": (["input_ids"], 16, False, "bert"), + "unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 16, False, "bert"), # ViT - "google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"), + "google/vit-base-patch16-224": (["pixel_values"], 16, False, "vit"), # Swin - "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 12, False, "swin"), - "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 12, False, "swin"), - "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 12, False, "swin"), + "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 16, False, "swin"), + "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 16, False, "swin"), + "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 16, False, "swin"), } diff --git a/onnxruntime/python/tools/transformers/machine_info.py b/onnxruntime/python/tools/transformers/machine_info.py index 288e36facb708..d4194abbd14d3 100644 --- a/onnxruntime/python/tools/transformers/machine_info.py +++ b/onnxruntime/python/tools/transformers/machine_info.py @@ -129,8 +129,6 @@ def get_related_packages(self) -> List[str]: related_packages = [ "onnxruntime-gpu", "onnxruntime", - "ort-nightly-gpu", - "ort-nightly", "onnx", "transformers", "protobuf", diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 2433ae3d9b5ee..af5afc54e5d56 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -249,11 +249,7 @@ def save_results(results, filename): installed_packages = pkg_resources.working_set installed_packages_list = sorted( - [ - f"{i.key}=={i.version}" - for i in installed_packages - if i.key in ["ort-nightly-gpu", "ort-nightly", "onnxruntime", "onnxruntime-gpu"] - ] + [f"{i.key}=={i.version}" for i in installed_packages if i.key in ["onnxruntime", "onnxruntime-gpu"]] ) ort_pkg_name = "" diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md index 26385896aa49b..e7cafeffc6231 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/README.md +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -89,13 +89,15 @@ It is able to run demo on optimized model as well. For example, python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --dtype fp16 --use_gpu --demo ``` -## Benchmark +## Benchmark and Profiling We can create a conda environment then run GPU benchmark like the following: ```bash conda create -n sam2_gpu python=3.11 -y conda activate sam2_gpu -bash benchmark_sam2.sh $HOME gpu +install_dir=$HOME +profiling=true +bash benchmark_sam2.sh $install_dir gpu $profiling ``` or create a new conda environment for CPU benchmark: @@ -107,13 +109,14 @@ bash benchmark_sam2.sh $HOME cpu The first parameter is a directory to clone git repositories or install CUDA/cuDNN for benchmark. The second parameter can be either "gpu" or "cpu", which indicates the device to run benchmark. +The third parameter is optional. Value "true" will enable profiling after running benchmarking on GPU. 
The script will automatically install required packages in current conda environment, download checkpoints, export onnx, -and run demo, benchmark and profiling. +and run demo, benchmark and optionally run profiling. * The performance test result is in sam2_gpu.csv or sam2_cpu.csv, which can be loaded into Excel. * The demo output is sam2_demo_fp16_gpu.png or sam2_demo_fp32_cpu.png. -* The profiling results are in *.nsys-rep or *.json files in current directory. +* The profiling results are in *.nsys-rep or *.json files in current directory. Use Nvidia NSight System to view the *.nsys-rep file. ## Limitations - The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py index 7e108b1638546..f75a4527be57d 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py @@ -90,26 +90,23 @@ def shape_dict(self) -> Mapping[str, List[int]]: else: return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks) - def random_inputs(self): + def random_inputs(self) -> Mapping[str, torch.Tensor]: + dtype = self.dtype if self.component == "image_encoder": - return { - "image": torch.randn( - self.batch_size, 3, self.height, self.width, dtype=torch.float32, device=self.device - ) - } + return {"image": torch.randn(self.batch_size, 3, self.height, self.width, dtype=dtype, device=self.device)} else: return { - "image_features_0": torch.rand(1, 32, 256, 256, dtype=torch.float32, device=self.device), - "image_features_1": torch.rand(1, 64, 128, 128, dtype=torch.float32, device=self.device), - "image_embeddings": torch.rand(1, 256, 64, 64, dtype=torch.float32, device=self.device), + "image_features_0": torch.rand(1, 32, 256, 256, dtype=dtype, device=self.device), + "image_features_1": torch.rand(1, 64, 128, 128, dtype=dtype, device=self.device), + "image_embeddings": torch.rand(1, 256, 64, 64, dtype=dtype, device=self.device), "point_coords": torch.randint( - 0, 1024, (self.num_labels, self.num_points, 2), dtype=torch.float32, device=self.device + 0, 1024, (self.num_labels, self.num_points, 2), dtype=dtype, device=self.device ), "point_labels": torch.randint( 0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device ), - "input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=torch.float32, device=self.device), - "has_input_masks": torch.ones(self.num_labels, dtype=torch.float32, device=self.device), + "input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=dtype, device=self.device), + "has_input_masks": torch.ones(self.num_labels, dtype=dtype, device=self.device), "original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device), } @@ -314,7 +311,7 @@ def run_test( width=args.width, device=device, use_tf32=True, - enable_cuda_graph=False, + enable_cuda_graph=enable_cuda_graph, dtype=dtypes[args.dtype], prefer_nhwc=args.prefer_nhwc, repeats=args.repeats, diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh index f8c5abdb75311..9e97867657ab9 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh @@ -1,48 +1,35 @@ #!/bin/bash # 
------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# -------------------------------------------------------------------------- - -# Here assumes that we are using conda (Anaconda/Miniconda/Miniforge) environment. -# For example, you can create a new conda environment like the following before running this script: -# conda create -n sam2_gpu python=3.11 -y -# conda activate sam2_gpu -# bash benchmark_sam2.sh $HOME gpu -# Or create a new conda environment for CPU benchmark: -# conda create -n sam2_cpu python=3.11 -y -# conda activate sam2_cpu -# bash benchmark_sam2.sh $HOME cpu - -python=$CONDA_PREFIX/bin/python3 - -# Directory of the script -dir="$( cd "$( dirname "$0" )" && pwd )" - -# Directory of the onnx models -onnx_dir=$dir/sam2_onnx_models - -# Directory to install CUDA, cuDNN, and git clone sam2 or onnxruntime source code. -install_dir=$HOME -if [ $# -ge 1 ]; then - install_dir=$1 -fi +# ------------------------------------------------------------------------- + +# Please refer to README.md for the prerequisites and usage of this script. +# bash benchmark_sam2.sh [profiling] + +python="$CONDA_PREFIX/bin/python3" + +# Directory of the script and ONNX models +dir="$(cd "$(dirname "$0")" && pwd)" +onnx_dir="$dir/sam2_onnx_models" -if ! [ -d $install_dir ]; then - echo "install_dir: $install_dir does not exist." +# Installation directory (default: $HOME) +install_dir="${1:-$HOME}" + +if [ ! -d "$install_dir" ]; then + echo "Error: install_dir '$install_dir' does not exist." exit 1 fi -# Directory of the sam2 code by "git clone https://github.com/facebookresearch/segment-anything-2" -sam2_dir=$install_dir/segment-anything-2 - -# model name to benchmark -model=sam2_hiera_large +# SAM2 code directory and model to benchmark +sam2_dir="$install_dir/segment-anything-2" +model="sam2_hiera_large" -# Default to use GPU if available. -cpu_or_gpu="gpu" -if [ $# -ge 2 ] && ([ "$2" = "gpu" ] || [ "$2" = "cpu" ]); then - cpu_or_gpu=$2 +# Default to GPU, switch to CPU if specified +cpu_or_gpu="${2:-gpu}" +if [ "$cpu_or_gpu" != "gpu" ] && [ "$cpu_or_gpu" != "cpu" ]; then + echo "Invalid option: $2. Please specify 'cpu' or 'gpu'." 
+ exit 1 fi echo "install_dir: $install_dir" @@ -51,185 +38,183 @@ echo "cpu_or_gpu: $cpu_or_gpu" install_cuda_12() { pushd $install_dir - wget https://developer.download.nvidia.com/compute/cuda/12.5.1/local_installers/cuda_12.5.1_555.42.06_linux.run - sh cuda_12.5.1_555.42.06_linux.run --toolkit --toolkitpath=$install_dir/cuda12.5 --silent --override --no-man-page + wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath=$install_dir/cuda12.6 --silent --override --no-man-page - export PATH="$install_dir/cuda12.5/bin:$PATH" - export LD_LIBRARY_PATH="$install_dir/cuda12.5/lib64:$LD_LIBRARY_PATH" + export PATH="$install_dir/cuda12.6/bin:$PATH" + export LD_LIBRARY_PATH="$install_dir/cuda12.6/lib64:$LD_LIBRARY_PATH" popd } -install_cudnn_9() -{ - pushd $install_dir - wget https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.4.0.58_cuda12-archive.tar.xz - mkdir $install_dir/cudnn9.4 - tar -Jxvf cudnn-linux-x86_64-9.4.0.58_cuda12-archive.tar.xz -C $install_dir/cudnn9.4 --strip=1 --no-overwrite-dir - - export LD_LIBRARY_PATH="$install_dir/cudnn9.4/lib:$LD_LIBRARY_PATH" +# Function to install cuDNN 9.4 +install_cudnn_9() { + pushd "$install_dir" + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz + mkdir -p "$install_dir/cudnn9.5" + tar -Jxvf cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz -C "$install_dir/cudnn9.5" --strip=1 + export LD_LIBRARY_PATH="$install_dir/cudnn9.5/lib:$LD_LIBRARY_PATH" popd } -install_gpu() -{ - if ! [ -d $install_dir/cuda12.5 ]; then - install_cuda_12 - fi - - if ! [ -d $install_dir/cudnn9.4 ]; then - install_cudnn_9 - fi +# Install GPU dependencies +install_gpu() { + [ ! -d "$install_dir/cuda12.6" ] && install_cuda_12 + [ ! -d "$install_dir/cudnn9.5" ] && install_cudnn_9 pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 pip install onnxruntime-gpu onnx opencv-python matplotlib } -install_cpu() -{ +# Install CPU dependencies +install_cpu() { pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install onnxruntime onnx opencv-python matplotlib } -install_sam2() -{ - pushd $install_dir - - if ! [ -d $install_dir/segment-anything-2 ]; then +# Clone and install SAM2 if not already installed +install_sam2() { + pushd "$install_dir" + if [ ! -d "$sam2_dir" ]; then git clone https://github.com/facebookresearch/segment-anything-2.git fi - - cd segment-anything-2 - - if pip show SAM-2 > /dev/null 2>&1; then - echo "SAM-2 is already installed." - else - pip install -e . - fi - - if ! [ -f checkpoints/sam2_hiera_large.pt ]; then - echo "Downloading checkpoints..." - cd checkpoints - sh ./download_ckpts.sh - fi - + cd "$sam2_dir" + pip show SAM-2 > /dev/null 2>&1 || pip install -e . + [ ! -f checkpoints/sam2_hiera_large.pt ] && (cd checkpoints && sh ./download_ckpts.sh) popd } -download_test_image() -{ - if ! [ -f truck.jpg ]; then - curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg - fi +# Download test image if not available +download_test_image() { + [ ! 
-f truck.jpg ] && curl -sO https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg } -run_cpu() -{ - repeats=$1 +run_cpu_benchmark() { + local repeats="$1" + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo - $python convert_to_onnx.py --sam2_dir $sam2_dir --optimize --demo + for component in image_encoder image_decoder; do + $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --dtype fp32 --component "$component" - echo "Benchmarking SAM2 model $model image encoder for PyTorch ..." - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32 - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16 + # Run ONNX Runtime on exported model (not optimized) + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --onnx_path "${onnx_dir}/${model}_${component}.onnx" --dtype fp32 --component "$component" - echo "Benchmarking SAM2 model $model image encoder for PyTorch ..." - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32 --component image_decoder - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16 --component image_decoder + # Run ONNX Runtime on optimized model + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --onnx_path "${onnx_dir}/${model}_${component}_fp32_cpu.onnx" --dtype fp32 --component "$component" + done +} - echo "Benchmarking SAM2 model $model image encoder for ORT ..." - $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder.onnx --dtype fp32 - $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_cpu.onnx --dtype fp32 +run_gpu_benchmark() { + local repeats="$1" + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo - echo "Benchmarking SAM2 model $model image decoder for ORT ..." 
- $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder.onnx --component image_decoder - $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_cpu.onnx --component image_decoder -} + for component in image_encoder image_decoder; do + for dtype in bf16 fp32 fp16; do + $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" + done + done -run_gpu() -{ - repeats=$1 + component="image_encoder" + for dtype in fp32 fp16; do + #TODO: --prefer_nhwc does not help with performance + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph + done - $python convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp32 - $python convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp16 --demo + component="image_decoder" + for dtype in fp32 fp16; do + # TODO: decoder does not work with cuda graph + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" + done +} - echo "Benchmarking SAM2 model $model image encoder for PyTorch ..." - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 +run_torch_compile_gpu_benchmark() { + local repeats="$1" - # Test different torch compile modes on image encoder (none will disable compile and use eager mode). + # Test different torch compile modes on image encoder for torch_compile_mode in none max-autotune reduce-overhead max-autotune-no-cudagraphs do - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode done +} - echo "Benchmarking SAM2 model $model image decoder for PyTorch ..." - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 --component image_decoder - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 --component image_decoder - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_decoder - echo "Benchmarking SAM2 model $model image encoder for ORT ..." 
- $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp16_gpu.onnx --use_gpu --dtype fp16 - $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_gpu.onnx --use_gpu --dtype fp32 +# Main script +run_benchmarks() { + if [ ! -v CONDA_PREFIX ]; then + echo "Please activate conda environment before running this script." + exit 1 + fi - echo "Benchmarking SAM2 model $model image decoder for ORT ..." - $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp16_gpu.onnx --component image_decoder --use_gpu --dtype fp16 - $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_gpu.onnx --component image_decoder --use_gpu + # Install dependencies + [ "$cpu_or_gpu" = "gpu" ] && install_gpu || install_cpu + install_sam2 + download_test_image + + # Run benchmarks + output_csv="sam2_${cpu_or_gpu}.csv" + if [ ! -f "$output_csv" ]; then + echo "Running $cpu_or_gpu benchmark..." + if [ "$cpu_or_gpu" = "gpu" ]; then + run_gpu_benchmark 1000 + run_torch_compile_gpu_benchmark 1000 + else + run_cpu_benchmark 100 + fi + cat benchmark*.csv > combined_csv + awk '!x[$0]++' combined_csv > "$output_csv" + rm combined_csv + echo "Benchmark results saved in $output_csv" + else + echo "$output_csv already exists, skipping benchmark..." + fi } -# Build onnxruntime-gpu from source for profiling. -build_onnxruntime_gpu_for_profiling() -{ - pushd $install_dir +run_benchmarks + +#-------------------------------------------------------------------------- +# Below are for profiling +#-------------------------------------------------------------------------- + +# Build onnxruntime-gpu from source for profiling +build_onnxruntime_gpu_for_profiling() { + pushd "$install_dir" if ! [ -d onnxruntime ]; then git clone https://github.com/microsoft/onnxruntime fi cd onnxruntime - - # Get the CUDA compute capability of the GPU. CUDA_ARCH=$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')") - if [ -n "$CUDA_ARCH" ]; then pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==1.26.4 sh build.sh --config Release --build_dir build/cuda12 --build_shared_lib --parallel \ - --use_cuda --cuda_version 12.5 --cuda_home $install_dir/cuda12.5 \ - --cudnn_home $install_dir/cudnn9.4 \ + --use_cuda --cuda_version 12.6 --cuda_home $install_dir/cuda12.6 \ + --cudnn_home $install_dir/cudnn9.5 \ --build_wheel --skip_tests \ --cmake_generator Ninja \ --compile_no_warning_as_error \ - --enable_cuda_nhwc_ops \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \ --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ --enable_cuda_line_info pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl numpy==1.26.4 else - echo "PyTorch is not installed or No CUDA device found." + echo "No CUDA device found." exit 1 fi - popd } # Run profiling with NVTX. run_nvtx_profile() { - pip install nvtx cuda-python==12.5.0 + pip install nvtx cuda-python==12.6.0 # Only trace one device to avoid huge output file size. 
device_id=0 - - # Environment variables envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH" - - # For cuda graphs, node activities will be collected and CUDA graphs will not be traced as a whole. - # This may cause significant runtime overhead. But it is useful to understand the performance of individual nodes. cuda_graph_trace=node - - for engine in ort torch - do - for component in image_encoder image_decoder - do - sudo $install_dir/cuda12.5/bin/nsys profile --capture-range=nvtx --nvtx-capture='one_run' \ + for engine in ort torch; do + for component in image_encoder image_decoder; do + sudo $install_dir/cuda12.6/bin/nsys profile --capture-range=nvtx --nvtx-capture='one_run' \ --gpu-metrics-device $device_id --force-overwrite true \ --sample process-tree --backtrace fp --stats true \ -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \ @@ -246,10 +231,8 @@ run_nvtx_profile() } # Run profiling with PyTorch -run_torch_profile() -{ - for component in image_encoder image_decoder - do +run_torch_profile() { + for component in image_encoder image_decoder; do $python benchmark_sam2.py --model_type $model --engine torch \ --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \ --component $component \ @@ -257,50 +240,16 @@ run_torch_profile() done } -if ! [ -v CONDA_PREFIX ]; then - echo "Please activate conda environment before running this script." - exit 1 -fi - -# Check whether nvidia-smi is available to determine whether to install GPU or CPU version. -if [ "$cpu_or_gpu" = "gpu" ]; then - install_gpu -else - install_cpu -fi - -install_sam2 - -download_test_image - -if ! [ -f sam2_${cpu_or_gpu}.csv ]; then - if [ "$cpu_or_gpu" = "gpu" ]; then - echo "Running GPU benchmark..." - run_gpu 1000 - else - echo "Running CPU benchmark..." - run_cpu 100 - fi +run_profilings() { + build_onnxruntime_gpu_for_profiling - cat benchmark*.csv > combined_csv - awk '!x[$0]++' combined_csv > sam2_${cpu_or_gpu}.csv - rm combined_csv - - echo "Benchmarking SAM2 model $model results are saved in sam2_${cpu_or_gpu}.csv" -else - echo "sam2_${cpu_or_gpu}.csv already exists, skipping benchmarking..." -fi - -if [ "$cpu_or_gpu" = "gpu" ]; then - echo "Running GPU profiling..." - if ! [ -f sam2_fp16_profile_image_decoder_ort_${cpu_or_gpu}.nsys-rep ]; then - rm -f *.nsys-rep - rm -f *.sqlite - build_onnxruntime_gpu_for_profiling - run_nvtx_profile - else - echo "sam2_fp16_profile_image_decoder_ort_${cpu_or_gpu}.nsys-rep already exists, skipping GPU profiling..." - fi + rm -f *.nsys-rep *.sqlite + run_nvtx_profile run_torch_profile +} + +profiling="${3:-false}" +if [ "$profiling" = "true" ] && [ "$cpu_or_gpu" = "gpu" ]; then + run_profilings fi diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 9c1c31626066d..edef0d3ee5453 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -40,9 +40,8 @@ docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:24.04-p ``` #### Build onnxruntime from source -The cuDNN in the container might not be compatible with official onnxruntime-gpu package, it is recommended to build from source instead. +This step is optional. Please look at [install onnxruntime-gpu](https://onnxruntime.ai/docs/install/#python-installs) if you do not want to build from source. 
-After launching the docker, you can build and install onnxruntime-gpu wheel like the following. ``` export CUDACXX=/usr/local/cuda/bin/nvcc git config --global --add safe.directory '*' @@ -60,9 +59,17 @@ If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line If your machine has less than 64GB memory, replace `--parallel` by `--parallel 4 --nvcc_threads 1 ` to avoid out of memory. #### Install required packages +First, remove older version of opencv to avoid error like `module 'cv2.dnn' has no attribute 'DictValue'`: +``` +pip uninstall -y $(pip list --format=freeze | grep opencv) +rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ +apt-get update +DEBIAN_FRONTEND="noninteractive" apt-get install --yes python3-opencv +``` + ``` cd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion -python3 -m pip install -r requirements-cuda12.txt +python3 -m pip install -r requirements/cuda12/requirements.txt python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com ``` @@ -136,15 +143,18 @@ conda activate py310 ### Setup Environment (CUDA) without docker -First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) 8.5 or above, and [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) in the machine. +First, we need install CUDA 11.8 or 12.x, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html), and [TensorRT](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) in the machine. + +The verison of CuDNN can be found in https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements. +The version of TensorRT can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#requirements. #### CUDA 11.8: -In the Conda environment, install PyTorch 2.1 or above, and other required packages like the following: +In the Conda environment, install PyTorch 2.1 up to 2.3.1, and other required packages like the following: ``` -pip install torch --index-url https://download.pytorch.org/whl/cu118 +pip install torch>=2.1,<2.4 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com -pip install -r requirements-cuda11.txt +pip install -r requirements/cuda11/requirements.txt ``` For Windows, install nvtx like the following: @@ -157,77 +167,40 @@ We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. Like `pip install tensorrt-8.6.1.6.windows10.x86_64.cuda-11.8\tensorrt-8.6.1.6\python\tensorrt-8.6.1-cp310-none-win_amd64.whl`. #### CUDA 12.*: -The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CUDA 12.*, you will need [build onnxruntime from source](https://onnxruntime.ai/docs/build/inferencing.html). - -``` -git clone --recursive https://github.com/Microsoft/onnxruntime.git -cd onnxruntime -pip install cmake -pip install -r requirements-dev.txt -``` -Follow [example script for A100 in Ubuntu](https://github.com/microsoft/onnxruntime/blob/26a7b63716e3125bfe35fe3663ba10d2d7322628/build_release.sh) -or [example script for RTX 4090 in Windows](https://github.com/microsoft/onnxruntime/blob/8df5f4e0df1f3b9ceeb0f1f2561b09727ace9b37/build_trt.cmd) to build and install onnxruntime-gpu wheel. 
- -Then install other python packages like the following: +The official package of onnxruntime-gpu 1.19.x is built for CUDA 12.x. You can install it and other python packages like the following: ``` -pip install torch --index-url https://download.pytorch.org/whl/cu121 +pip install onnxruntime-gpu +pip install torch --index-url https://download.pytorch.org/whl/cu124 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com -pip install -r requirements-cuda12.txt +pip install -r requirements/cuda12/requirements.txt ``` Finally, `pip install tensorrt` for Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. ### Setup Environment (ROCm) -It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.10. +It is recommended that the users run the model with ROCm 6.2 or newer and Python 3.10. You can follow the following to install ROCm 6.x: https://rocmdocs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html Note that Windows is not supported for ROCm at the moment. ``` -wget https://repo.radeon.com/rocm/manylinux/rocm-rel-5.4/torch-1.12.1%2Brocm5.4-cp38-cp38-linux_x86_64.whl -pip install torch-1.12.1+rocm5.4-cp38-cp38-linux_x86_64.whl -pip install -r requirements-rocm.txt +pip install -r requirements/rocm/requirements.txt ``` -AMD GPU version of PyTorch can be installed from [pytorch.org](https://pytorch.org/get-started/locally/) or [AMD Radeon repo](https://repo.radeon.com/rocm/manylinux/rocm-rel-5.4/). +AMD GPU version of PyTorch can be installed from [pytorch.org](https://pytorch.org/get-started/locally/) or [AMD Radeon repo](https://repo.radeon.com/rocm/manylinux/rocm-rel-6.2.3/). #### Install onnxruntime-rocm -Here is an example to build onnxruntime from source with Rocm 5.4.2 in Ubuntu 20.04, and install the wheel. - -(1) Install [ROCm 5.4.2](https://docs.amd.com/bundle/ROCm-Installation-Guide-v5.4.2/page/How_to_Install_ROCm.html). Note that the version is also used in PyTorch 2.0 ROCm package. - -(2) Install some tools used in build: -``` -sudo apt-get update -sudo apt-get install -y --no-install-recommends \ - wget \ - zip \ - ca-certificates \ - build-essential \ - curl \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev -pip install numpy packaging "wheel>=0.35.1" -wget --quiet https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz -tar zxf cmake-3.26.3-linux-x86_64.tar.gz -export PATH=${PWD}/cmake-3.26.3-linux-x86_64/bin:${PATH} -``` - -(3) Build and Install ONNX Runtime +One option is to install prebuilt wheel from https://repo.radeon.com/rocm/manylinux like: ``` -git clone https://github.com/microsoft/onnxruntime -cd onnxruntime -sh build.sh --config Release --use_rocm --rocm_home /opt/rocm --rocm_version 5.4.2 --build_wheel -pip install build/Linux/Release/dist/*.whl +wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.2.3/onnxruntime_rocm-1.18.0-cp310-cp310-linux_x86_64.whl +pip install onnxruntime_rocm-1.18.0-cp310-cp310-linux_x86_64.whl ``` -You can also follow the [official docs](https://onnxruntime.ai/docs/build/eps.html#amd-rocm) to build with docker. +If you want to use latest version of onnxruntime, you can build from source with Rocm 6.x following https://onnxruntime.ai/docs/build/eps.html#amd-rocm. +When the build is finished, you can install the wheel:`pip install build/Linux/Release/dist/*.whl`. 
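Whichever install path you pick above (CUDA or ROCm, prebuilt wheel or source build), a quick sanity check that the expected execution provider is actually available is (a sketch, assuming the wheel is installed in the active environment):

```python
import onnxruntime as ort

print(ort.__version__)
# Expect CUDAExecutionProvider (CUDA wheel) or ROCMExecutionProvider (ROCm wheel) in this list.
print(ort.get_available_providers())
```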
### Export ONNX pipeline This step will export stable diffusion 1.5 to ONNX model in float32 using script from diffusers. -It is recommended to use PyTorch 1.12.1 or 1.13.1 in this step. Using PyTorch 2.0 will encounter issue in exporting onnx. - ``` curl https://raw.githubusercontent.com/huggingface/diffusers/v0.15.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py > convert_sd_onnx.py python convert_sd_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./sd_v1_5/fp32 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 3879e25386d53..0708d57f040f8 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -51,6 +51,10 @@ def example_prompts(): return prompts, negative_prompt +def warmup_prompts(): + return "warm up", "bad" + + def measure_gpu_memory(monitor_type, func, start_memory=None): return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory) @@ -136,7 +140,14 @@ def run_ort_pipeline( prompts, negative_prompt = example_prompts() def warmup(): - pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) + prompt, negative = warmup_prompts() + pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative] * batch_size, + ) # Run warm up, and measure GPU memory of two runs # cuDNN/MIOpen The first run has algo search so it might need more memory) @@ -149,22 +160,20 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - images = pipe( - [prompt] * batch_size, - height, - width, - num_inference_steps=steps, - negative_prompt=[negative_prompt] * batch_size, - guidance_scale=7.5, - ).images - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + inference_start = time.time() + images = pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative_prompt] * batch_size, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") from onnxruntime import __version__ as ort_version @@ -200,7 +209,14 @@ def run_torch_pipeline( # total 2 runs of warm up, and measure GPU memory for CUDA EP def warmup(): - pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) + prompt, negative = warmup_prompts() + pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative] * batch_size, + ) # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory) first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) @@ -215,25 +231,23 @@ def warmup(): if i >= num_prompts: break torch.cuda.synchronize() - for j in range(batch_count): - inference_start = time.time() - images = pipe( - prompt=[prompt] * batch_size, - 
height=height, - width=width, - num_inference_steps=steps, - guidance_scale=7.5, - negative_prompt=[negative_prompt] * batch_size, - generator=None, # torch.Generator - ).images + inference_start = time.time() + images = pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative_prompt] * batch_size, + generator=None, # torch.Generator + ).images - torch.cuda.synchronize() - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + torch.cuda.synchronize() + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") return { "engine": "torch", @@ -306,6 +320,7 @@ def get_optimum_ort_pipeline( directory: str, provider="CUDAExecutionProvider", disable_safety_checker: bool = True, + use_io_binding: bool = False, ): from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline @@ -321,7 +336,7 @@ def get_optimum_ort_pipeline( pipeline = ORTStableDiffusionPipeline.from_pretrained( directory, provider=provider, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. + use_io_binding=use_io_binding, ) elif "xl" in model_name: pipeline = ORTStableDiffusionXLPipeline.from_pretrained( @@ -337,7 +352,7 @@ def get_optimum_ort_pipeline( model_name, export=True, provider=provider, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. + use_io_binding=use_io_binding, ) pipeline.save_pretrained(directory) @@ -359,15 +374,33 @@ def run_optimum_ort_pipeline( batch_count, start_memory, memory_monitor_type, + use_num_images_per_prompt=False, ): from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline assert isinstance(pipe, (ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline)) - prompts = example_prompts() + prompts, negative_prompt = example_prompts() def warmup(): - pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) + prompt, negative = warmup_prompts() + if use_num_images_per_prompt: + pipe( + prompt=prompt, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=negative, + num_images_per_prompt=batch_count, + ) + else: + pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative] * batch_size, + ) # Run warm up, and measure GPU memory of two runs. # The first run has algo search for cuDNN/MIOpen, so it might need more memory. 
@@ -380,23 +413,30 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() + inference_start = time.time() + if use_num_images_per_prompt: images = pipe( - prompt, - height, - width, + prompt=prompt, + height=height, + width=width, num_inference_steps=steps, - negative_prompt=None, - guidance_scale=0.0, # 7.5 + negative_prompt=negative_prompt, num_images_per_prompt=batch_size, ).images - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + else: + images = pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative_prompt] * batch_size, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") from onnxruntime import __version__ as ort_version @@ -429,9 +469,12 @@ def run_optimum_ort( batch_count: int, start_memory, memory_monitor_type, + use_io_binding: bool = False, ): load_start = time.time() - pipe = get_optimum_ort_pipeline(model_name, directory, provider, disable_safety_checker) + pipe = get_optimum_ort_pipeline( + model_name, directory, provider, disable_safety_checker, use_io_binding=use_io_binding + ) load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") @@ -530,9 +573,8 @@ def run_ort_trt_static( pipeline.load_resources(height, width, batch_size) def warmup(): - pipeline.run( - ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True - ) + prompt, negative = warmup_prompts() + pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -548,24 +590,23 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - images, pipeline_time = pipeline.run( - [prompt] * batch_size, - [negative_prompt] * batch_size, - height, - width, - denoising_steps=steps, - guidance=7.5, - seed=123, - ) - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + ) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. 
Inference latency: {pipeline_time}") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") pipeline.teardown() @@ -671,9 +712,8 @@ def run_tensorrt_static( pipeline.load_resources(height, width, batch_size) def warmup(): - pipeline.run( - ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True - ) + prompt, negative = warmup_prompts() + pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -689,24 +729,22 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - images, pipeline_time = pipeline.run( - [prompt] * batch_size, - [negative_prompt] * batch_size, - height, - width, - denoising_steps=steps, - guidance=7.5, - seed=123, - ) - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + seed=123, + ) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") pipeline.teardown() @@ -828,7 +866,8 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None): ) def warmup(): - run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size) + prompt, negative = warmup_prompts() + run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -845,20 +884,15 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - if nvtx_profile: - cudart.cudaProfilerStart() - images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) - if nvtx_profile: - cudart.cudaProfilerStop() - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.png") + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. 
Inference latency: {pipeline_time}") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.png") pipeline.teardown() @@ -911,8 +945,6 @@ def run_ort_trt_xl( opt_batch_size=batch_size, ) - from cuda import cudart - assert batch_size <= max_batch_size pipeline.load_resources(height, width, batch_size) @@ -929,7 +961,8 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None): ) def warmup(): - run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size) + prompt, negative = warmup_prompts() + run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -946,22 +979,17 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - if nvtx_profile: - cudart.cudaProfilerStart() - images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) - if nvtx_profile: - cudart.cudaProfilerStop() - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - filename = f"{image_filename_prefix}_{i}_{j}_{k}.png" - image.save(filename) - print("Image saved to", filename) + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") + for k, image in enumerate(images): + filename = f"{image_filename_prefix}_{i}_{k}.png" + image.save(filename) + print("Image saved to", filename) pipeline.teardown() @@ -1137,6 +1165,14 @@ def parse_arguments(): ) parser.set_defaults(use_xformers=False) + parser.add_argument( + "--use_io_binding", + required=False, + action="store_true", + help="Use I/O Binding for Optimum.", + ) + parser.set_defaults(use_io_binding=False) + parser.add_argument( "-b", "--batch_size", @@ -1176,8 +1212,8 @@ def parse_arguments(): "--num_prompts", required=False, type=int, - default=1, - help="Number of prompts. Default is 1.", + default=10, + help="Number of prompts. 
Default is 10.", ) parser.add_argument( @@ -1312,6 +1348,7 @@ def main(): batch_count=args.batch_count, start_memory=start_memory, memory_monitor_type=memory_monitor_type, + use_io_binding=args.use_io_binding, ) elif args.engine == "onnxruntime": assert args.pipeline and os.path.isdir( diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index 0b6d325803554..7609ae10fc96d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -117,7 +117,7 @@ def get_cached_model_name(self, model_name): model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet) if hash_source: - model_name += "_" + hashlib.md5("\t".join(hash_source).encode("utf-8")).hexdigest()[:8] + model_name += "_" + hashlib.sha256("\t".join(hash_source).encode("utf-8")).hexdigest()[:8] # TODO: When we support original VAE, we shall save custom VAE to another directory. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-rocm.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-rocm.txt deleted file mode 100644 index c0a925e25b941..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-rocm.txt +++ /dev/null @@ -1,5 +0,0 @@ --r requirements.txt -# Install onnxruntime-rocm or onnxruntime_training -# Build onnxruntime-rocm from source -# Directly install pre-built onnxruntime/onnxruntime-training rocm python package is not possible at the moment. -# TODO: update once we have public pre-built packages diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/cuda11/requirements.txt similarity index 64% rename from onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt rename to onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/cuda11/requirements.txt index 4aa88cdf92309..bbc62ca4cbd18 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/cuda11/requirements.txt @@ -1,13 +1,13 @@ --r requirements.txt +-r ../requirements.txt -# For CUDA 12.*, you will need build onnxruntime-gpu from source and install the wheel. See README.md for detail. +# See https://onnxruntime.ai/docs/install/#python-installs for installation. The latest one in pypi is for cuda 12. # onnxruntime-gpu>=1.16.2 py3nvml # The version of cuda-python shall be compatible with installed CUDA version. # For demo of TensorRT excution provider and TensortRT. 
-cuda-python>=12.1.0 +cuda-python==11.8.0 # For windows, cuda-python need the following pywin32; platform_system == "Windows" @@ -15,8 +15,8 @@ pywin32; platform_system == "Windows" # For windows, run `conda install -c conda-forge nvtx` instead nvtx; platform_system != "Windows" -# Please install PyTorch 2.1 or above for 12.1 using one of the following commands: -# pip3 install torch --index-url https://download.pytorch.org/whl/cu121 +# Please install PyTorch >=2.1 and <2.4 for CUDA 11.8 like the following: +# pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cu118 # Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. # pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/cuda12/requirements.txt similarity index 73% rename from onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt rename to onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/cuda12/requirements.txt index dc6592fc2fa54..89562e920ac00 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/cuda12/requirements.txt @@ -1,13 +1,12 @@ --r requirements.txt +-r ../requirements.txt -# Official onnxruntime-gpu 1.16.1 is built with CUDA 11.8. -onnxruntime-gpu>=1.16.2 +onnxruntime-gpu>=1.19.2 py3nvml # The version of cuda-python shall be compatible with installed CUDA version. # For demo of TensorRT excution provider and TensortRT. -cuda-python==11.8.0 +cuda-python>=12.1.0 # For windows, cuda-python need the following pywin32; platform_system == "Windows" @@ -15,8 +14,8 @@ pywin32; platform_system == "Windows" # For windows, run `conda install -c conda-forge nvtx` instead nvtx; platform_system != "Windows" -# Please install PyTorch 2.1 or above for CUDA 11.8 using one of the following commands: -# pip3 install torch --index-url https://download.pytorch.org/whl/cu118 +# Please install PyTorch 2.4 or above using one of the following commands: +# pip3 install torch --index-url https://download.pytorch.org/whl/cu124 # Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. 
# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt similarity index 63% rename from onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt rename to onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt index 72ba4252e481c..5bdd422a11750 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt @@ -1,3 +1,4 @@ +huggingface_hub==0.25.2 diffusers==0.28.0 transformers==4.41.2 numpy>=1.24.1 @@ -9,10 +10,13 @@ packaging protobuf==3.20.3 psutil sympy -controlnet_aux==0.0.7 +nvtx==0.2.5 +torchvision==0.15.2 +tensorrt==8.5.1.7 +mediapipe +controlnet_aux==0.0.9 # The following are for SDXL optimum==1.20.0 safetensors invisible_watermark -# newer version of opencv-python migth encounter module 'cv2.dnn' has no attribute 'DictValue' error -opencv-python==4.8.0.74 +opencv-python-headless diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/rocm/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/rocm/requirements.txt new file mode 100644 index 0000000000000..21b100fb61f17 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/rocm/requirements.txt @@ -0,0 +1,2 @@ +-r ../requirements.txt +# Install onnxruntime-rocm that is built from source (https://onnxruntime.ai/docs/build/eps.html#amd-rocm) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt index e51ffb395c643..1938f59208ae7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt @@ -2,3 +2,4 @@ git+https://github.com/openai/CLIP.git open_clip_torch sentence_transformers pillow +numpy==1.22.2 diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index 814b0dd1ef6ac..b7f5c2294f395 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -291,11 +291,7 @@ def save_results(results, filename): installed_packages = pkg_resources.working_set installed_packages_list = sorted( - [ - f"{i.key}=={i.version}" - for i in installed_packages - if i.key in ["ort-nightly-gpu", "ort-nightly", "onnxruntime", "onnxruntime-gpu"] - ] + [f"{i.key}=={i.version}" for i in installed_packages if i.key in ["onnxruntime", "onnxruntime-gpu"]] ) ort_pkg_name = "" ort_pkg_version = "" diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index be05ebc9d5dac..87ac45101f0c0 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -323,4 +323,7 @@ def chain_model(args): convert_attribute=True, location=f"{os.path.basename(args.beam_model_output_dir)}.data", ) - 
onnx.checker.check_model(args.beam_model_output_dir, full_check=True) + try: + onnx.checker.check_model(args.beam_model_output_dir, full_check=True) + except Exception as e: + logger.error(f"An error occurred while running the ONNX checker: {e}", exc_info=True) # noqa: G201 diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 3967a7875f3a7..c3ccde50dac85 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -392,11 +392,13 @@ def validate_and_optimize_onnx( False, output_names, ) - if optimize_info == OptimizerInfo.NOOPT: + if optimize_info.name == OptimizerInfo.NOOPT.name: return onnx_model_path, is_valid_onnx_model, config.vocab_size if ( - optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8 + optimize_info.name == OptimizerInfo.BYSCRIPT.name + or precision == Precision.FLOAT16 + or precision == Precision.INT8 ): # Use script (optimizer.py) to optimize optimized_model_path = get_onnx_file_path( onnx_dir, @@ -439,7 +441,7 @@ def validate_and_optimize_onnx( QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format) logger.info(f"Finished quantizing model: {onnx_model_path}") - if optimize_info == OptimizerInfo.BYORT: # Use OnnxRuntime to optimize + if optimize_info.name == OptimizerInfo.BYORT.name: # Use OnnxRuntime to optimize if is_valid_onnx_model: ort_model_path = add_filename_suffix(onnx_model_path, "_ort") optimize_onnx_model_by_ort( @@ -492,7 +494,7 @@ def export_onnx_model_from_pt( example_inputs = image_processor(data, return_tensors="pt") else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt") example_inputs = filter_inputs(example_inputs, input_names) @@ -596,7 +598,7 @@ def export_onnx_model_from_tf( # Fix "Using pad_token, but it is not set yet" error. 
if tokenizer.pad_token is None: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier) model.resize_token_embeddings(len(tokenizer)) diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py index c781a91c9e493..efcd92129597a 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py @@ -178,18 +178,17 @@ def fuse_attention(self): mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) logger.debug("Create an Attention node.") attention_node = self.attention_fusion.create_attention_node( - mask_index, - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - self.num_heads, - self.hidden_size, - parent.output[0], - reshape_qkv.output[0], - None, + mask_index=mask_index, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=self.num_heads, + hidden_size=self.hidden_size, + first_input=parent.output[0], + output=reshape_qkv.output[0], ) if attention_node is None: continue diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py index b7891223e1dc2..a89b6c9e9395d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py @@ -480,18 +480,17 @@ def fuse_attention(self): # For tf models, q and v are flipped. attention_node = self.attention_fusion.create_attention_node( - mask_index, - matmul_k, - matmul_q, - matmul_v, - add_k, - add_q, - add_v, - self.num_heads, - self.hidden_size, - parent.output[0], - qkv_nodes[2].output[0], - None, + mask_index=mask_index, + q_matmul=matmul_k, + k_matmul=matmul_q, + v_matmul=matmul_v, + q_add=add_k, + k_add=add_q, + v_add=add_v, + num_heads=self.num_heads, + hidden_size=self.hidden_size, + first_input=parent.output[0], + output=qkv_nodes[2].output[0], ) if attention_node is None: continue diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh index 77d0c3a76624f..25997f40d348f 100755 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -1,13 +1,11 @@ +#!/bin/bash # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- # This measures the performance of OnnxRuntime, PyTorch and TorchScript on transformer models. -# Please install PyTorch (see https://pytorch.org/) before running this benchmark. Like the following: -# GPU: conda install pytorch torchvision cudatoolkit=11.0 -c pytorch -# CPU: conda install pytorch torchvision cpuonly -c pytorch -# To use torch2, please install the nightly PyTorch by replacing pytorch with pytorch-nightly. +# Please install PyTorch (see https://pytorch.org/) before running this benchmark. # When use_package=true, you need not copy other files to run benchmarks except this sh file. 
# Otherwise, it will use python script (*.py) files in this directory. @@ -59,7 +57,6 @@ sequence_lengths="8 16 32 64 128 256 512 1024" # Here we only test one input (input_ids) for fair comparison with PyTorch. input_counts=1 -# Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased models_to_test="bert-base-cased roberta-base distilbert-base-uncased" # If you have multiple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: @@ -91,7 +88,6 @@ fi if [ "$run_install" = true ] ; then - pip uninstall --yes ort-nightly ort-gpu-nightly pip uninstall --yes onnxruntime pip uninstall --yes onnxruntime-gpu if [ "$run_cpu_fp32" = true ] || [ "$run_cpu_int8" = true ]; then @@ -99,7 +95,7 @@ if [ "$run_install" = true ] ; then else pip install onnxruntime-gpu fi - pip install --upgrade onnx coloredlogs packaging psutil py3nvml onnxconverter_common numpy transformers sympy + pip install --upgrade onnx coloredlogs packaging psutil py3nvml numpy transformers sympy fi if [ "$use_package" = true ] ; then diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index e0891c7ced63e..acb520f894569 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -194,6 +194,24 @@ inline void CheckTensor(const Tensor& expected_tensor, const Tensor& output_tens } } +template +std::vector GetTypedArray(std::vector inputs) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_integral_v, + "Only float, double, MLFloat16, and integral types are supported."); + if constexpr (std::is_same::value) { + return inputs; + } else if constexpr (std::is_integral_v || std::is_same::value) { + std::vector result(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + result[i] = static_cast(inputs[i]); + } + return result; + } else { + return ToFloat16(inputs); + } +} + class ParallelRandomValueGenerator { public: using RandomEngine = std::default_random_engine; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 5f94d30112f0e..9f4ee071925b4 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -7,6 +7,8 @@ #include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#include "test/providers/model_tester.h" +#include "test/util/include/current_test_name.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" @@ -388,5 +390,47 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } } +TEST(BeamSearchTest, DummyT5) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx")); + 
tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 17c9e8592f64e..208545eacf224 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -15,23 +15,20 @@ namespace onnxruntime { namespace test { -// This op is currently only supported on CUDA- so test it only for CUDA -#ifdef USE_CUDA - template static std::vector CreateOnes(int size) { std::vector f; f.reserve(size); for (int i = 0; i < size; ++i) { - f.push_back(T(1)); + f.push_back(T(1.0f)); } return f; } template -static std::vector CreateValues(int size, int val) { +static std::vector CreateValues(int size, float val) { std::vector f; f.reserve(size); @@ -72,39 +69,25 @@ static std::vector CreateRandom(int size) { return f; } -// QKV template -static std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size); +float ToFloat(T val); template <> -std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size) { - std::vector qkv; - qkv.resize(batch_size * sequence_length * 3 * hidden_size, 0); - - for (int b = 0; b < batch_size; ++b) { - for (int i = 0; i < sequence_length; ++i) { - for (int j = 0; j < 3 * hidden_size; ++j) { - float sum = 0; - - for (int k = 0; k < hidden_size; ++k) { - sum += input[b * sequence_length * hidden_size + i * hidden_size + k] * weights[k * 3 * hidden_size + j]; - } - - qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = sum + bias[j]; - } - } - } - - return qkv; +constexpr float ToFloat(float val) { + return val; } template <> -std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size) { - std::vector qkv; - qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast(0.f)); +float ToFloat(MLFloat16 val) { + return val.ToFloat(); +} + +// QKV +template +static std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, + int batch_size, int sequence_length, int hidden_size) { + std::vector qkv; + qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int i = 0; i < sequence_length; ++i) { @@ -112,10 +95,11 @@ std::vector QKV(std::vector& 
input, std::vector float sum = 0; for (int k = 0; k < hidden_size; ++k) { - sum += input[b * sequence_length * hidden_size + i * hidden_size + k].ToFloat() * weights[k * 3 * hidden_size + j].ToFloat(); + sum += ToFloat(input[b * sequence_length * hidden_size + i * hidden_size + k]) * + ToFloat(weights[k * 3 * hidden_size + j]); } - qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = static_cast(sum + bias[j].ToFloat()); + qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = static_cast(sum + ToFloat(bias[j])); } } } @@ -180,15 +164,17 @@ void CheckEquality(T* data_1, T* data_2, int batch_size, int num_heads, int num_ // Reorder 'K' from [B, N, S, H] to [B, N, H/x, S, x] where x = (sizeof(T) / 16); // Copy 'V' over as is template -static std::vector ReorderKVCache(std::vector& unordered_k_cache, +static std::vector ReorderKVCache(const std::vector& unordered_k_cache, int batch_size, int num_heads, int sequence_length, - int head_size, int max_sequence_length) { + int head_size, int max_sequence_length, bool merge_past_kv = true) { std::vector ordered(unordered_k_cache.size(), T{0.f}); // Copy V over - size_t v_start = unordered_k_cache.size() / 2; - for (size_t i = v_start; i < unordered_k_cache.size(); ++i) { - ordered[i] = unordered_k_cache[i]; + if (merge_past_kv) { + size_t v_start = unordered_k_cache.size() / 2; + for (size_t i = v_start; i < unordered_k_cache.size(); ++i) { + ordered[i] = unordered_k_cache[i]; + } } // Now let us re-order K and copy it over to the final buffer @@ -203,7 +189,8 @@ static std::vector ReorderKVCache(std::vector& unordered_k_cache, (h * max_sequence_length * head_size); int input_base_offset = base_offset + (s * head_size) + (c * num_inner_elements); - int output_base_offset = base_offset + (c * max_sequence_length * num_inner_elements) + (s * num_inner_elements); + int output_base_offset = base_offset + (c * max_sequence_length * num_inner_elements) + + (s * num_inner_elements); for (int e = 0; e < num_inner_elements; ++e) { ordered[output_base_offset + e] = unordered_k_cache[input_base_offset + e]; @@ -224,7 +211,7 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache T* k, int batch_size, int num_heads, int past_sequence_length, int max_sequence_length, - int head_size) { + int head_size, bool merge_past_kv = true) { std::vector merged = ordered_k_cache; int total_seq_length = past_sequence_length + 1; @@ -249,10 +236,11 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache input_value = ordered_k_cache[input_offset]; } else { int hidden_size = num_heads * head_size; - int input_offset = (b * 3 * hidden_size) + - (n * num_chunks * chunk_size) + - (c * chunk_size) + - h; + int input_offset = merge_past_kv ? 
((b * 3 * hidden_size) + + (n * num_chunks * chunk_size) + + (c * chunk_size) + + h) + : ((b * hidden_size) + n * head_size + c * chunk_size + h); input_value = k[input_offset]; } @@ -272,7 +260,7 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache return merged; } -// GIven a pointer to the 'V' component of the past cache, we will merge it +// Given a pointer to the 'V' component of the past cache, we will merge it // with current 'V' in-place template static void MergeReorderedKVCacheWithV(T* v_cache, @@ -299,7 +287,8 @@ static void MergeReorderedKVCacheWithV(T* v_cache, template static std::pair, std::vector> MergePastKWithPresentKAndTranspose(T* past_k, T* present_k, int num_batch, int num_heads, - int past_sequence_length, int max_sequence_length, + int past_sequence_length, + int max_sequence_length, int head_size) { int total_seq_length = (past_sequence_length + 1); std::vector merged_k(num_batch * num_heads * total_seq_length * head_size, T{0.f}); @@ -312,16 +301,18 @@ static std::pair, std::vector> MergePastKWithPresentKAndTransp T input_value{0.f}; if (s < past_sequence_length) { - int input_offset = b * num_heads * max_sequence_length * head_size + (n * max_sequence_length * head_size) + (s * head_size) + h; + int input_offset = b * num_heads * max_sequence_length * head_size + + (n * max_sequence_length * head_size) + (s * head_size) + h; input_value = past_k[input_offset]; } else { int hidden_size = num_heads * head_size; - // Offset by 3* hidden_size because QKV data contains Q, K, and V per batch + // Offset by 3 * hidden_size because QKV data contains Q, K, and V per batch int input_offset = (b * 3 * hidden_size) + (n * head_size) + h; input_value = present_k[input_offset]; } - int output_offset = b * num_heads * total_seq_length * head_size + (n * total_seq_length * head_size) + (s * head_size) + h; + int output_offset = b * num_heads * total_seq_length * head_size + + (n * total_seq_length * head_size) + (s * head_size) + h; merged_k[output_offset] = input_value; } @@ -383,15 +374,11 @@ void ValidateReorderedMergedKWithK(T* k, T* k_cache, int batch_size, int num_hea // QK_Transpose template std::vector QK_Transpose(T* q_matrix, T* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size); - -template <> -std::vector QK_Transpose(float* q_matrix, float* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size) { + int batch_size, int num_heads, int total_sequence_length, int head_size) { int hidden_size = num_heads * head_size; - std::vector qk_transpose; - qk_transpose.resize(batch_size * num_heads * total_sequence_length, 0); + std::vector qk_transpose; + qk_transpose.resize(batch_size * num_heads * total_sequence_length, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { @@ -409,50 +396,12 @@ std::vector QK_Transpose(float* q_matrix, float* k_transpose_matrix, for (int j = 0; j < total_sequence_length; ++j) { float sum = 0; for (int k = 0; k < head_size; ++k) { - sum += (q_matrix[input_1_base_offset + i * head_size + k] * - k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j]); + sum += (ToFloat(q_matrix[input_1_base_offset + i * head_size + k]) * + ToFloat(k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j])); } float scale = 1 / sqrt(static_cast(head_size)); - qk_transpose[output_base_offset + i * total_sequence_length + j] = scale * sum; - } - } - } - } - - return qk_transpose; -} - 
-template <> -std::vector QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size) { - int hidden_size = num_heads * head_size; - - std::vector qk_transpose; - qk_transpose.resize(batch_size * num_heads * total_sequence_length, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int input_1_base_offset = (b * 3 * hidden_size) + - (n * head_size); - - int input_2_base_offset = (b * num_heads * total_sequence_length * head_size) + - (n * total_sequence_length * head_size); - - int output_base_offset = (b * num_heads * total_sequence_length) + - (n * total_sequence_length); - - // sequence_length == 1 - for (int i = 0; i < 1; ++i) { - for (int j = 0; j < total_sequence_length; ++j) { - float sum = 0; - for (int k = 0; k < head_size; ++k) { - sum += (q_matrix[input_1_base_offset + i * head_size + k].ToFloat() * - k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j].ToFloat()); - } - - float scale = 1 / sqrt(static_cast(head_size)); - qk_transpose[output_base_offset + i * total_sequence_length + j] = MLFloat16(scale * sum); + qk_transpose[output_base_offset + i * total_sequence_length + j] = static_cast(scale * sum); } } } @@ -464,26 +413,23 @@ std::vector QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_ // Softmax_QK_Transpose template std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int head_size); - -template <> -std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int /*head_size*/) { + int sequence_length, int total_sequence_length) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } - std::vector softmax_qk_transpose; - softmax_qk_transpose.resize(batch_size * num_heads * sequence_length * total_sequence_length, 0); + std::vector softmax_qk_transpose; + softmax_qk_transpose.resize(static_cast(batch_size) * num_heads * sequence_length * total_sequence_length, + static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { int base_offset = (b * num_heads * sequence_length * total_sequence_length) + (n * sequence_length * total_sequence_length); - float max = std::numeric_limits::min(); + float max = std::numeric_limits::lowest(); for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); if (val > max) { max = val; } @@ -491,52 +437,13 @@ std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_si float denom = 0; for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); denom += std::exp(val - max); } for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; - softmax_qk_transpose[base_offset + s] = std::exp(val - max) / (denom + (float)0.000001); - } - } - } - - return softmax_qk_transpose; -} - -template <> -std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int /*head_size*/) { - if (sequence_length != 1) { - throw std::runtime_error("Not supported"); - } - - std::vector softmax_qk_transpose; - softmax_qk_transpose.resize(batch_size * num_heads * 
sequence_length * total_sequence_length, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int base_offset = (b * num_heads * sequence_length * total_sequence_length) + - (n * sequence_length * total_sequence_length); - - float max = std::numeric_limits::min(); - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - if (val > max) { - max = val; - } - } - - float denom = 0; - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - denom += std::exp(val - max); - } - - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - softmax_qk_transpose[base_offset + s] = MLFloat16(std::exp(val - max) / (denom + (float)0.000001)); + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); + softmax_qk_transpose[base_offset + s] = static_cast(std::exp(val - max) / (denom + (float)0.000001)); } } } @@ -550,19 +457,13 @@ std::vector Softmax_QK_Transpose_V(T* softmax_qk_transpose_matrix, T* v_matrix, int batch_size, int num_heads, int sequence_length, int total_sequence_length, int max_sequence_length, - int head_size); -template <> -std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, - float* v_matrix, - int batch_size, int num_heads, int sequence_length, - int total_sequence_length, int max_sequence_length, - int head_size) { + int head_size) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } - std::vector output; - output.resize(batch_size * sequence_length * num_heads * head_size, 0); + std::vector output; + output.resize(batch_size * sequence_length * num_heads * head_size, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { @@ -580,11 +481,11 @@ std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, float sum = 0; for (int k = 0; k < total_sequence_length; ++k) { - sum += (softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k] * - v_matrix[input_2_base_offset + k * head_size + j]); + sum += (ToFloat(softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k]) * + ToFloat(v_matrix[input_2_base_offset + k * head_size + j])); } - output[output_base_offset + i * head_size + j] = sum; + output[output_base_offset + i * head_size + j] = static_cast(sum); } } } @@ -593,48 +494,11 @@ std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, return output; } -template <> -std::vector Softmax_QK_Transpose_V(MLFloat16* softmax_qk_transpose_matrix, - MLFloat16* v_matrix, - int batch_size, int num_heads, int sequence_length, - int total_sequence_length, int max_sequence_length, - int head_size) { - if (sequence_length != 1) { - throw std::runtime_error("Not supported"); - } - - std::vector output; - output.resize(batch_size * sequence_length * num_heads * head_size, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int input_1_base_offset = (b * num_heads * sequence_length * total_sequence_length) + - (n * sequence_length * total_sequence_length); - - int input_2_base_offset = (b * num_heads * max_sequence_length * head_size) + - (n * max_sequence_length * head_size); - - int output_base_offset = (b * num_heads * sequence_length * head_size) + - (n * sequence_length * head_size); - - for (int i = 0; i < sequence_length; ++i) { - for (int j = 0; j < head_size; ++j) { - float sum 
= 0; - - for (int k = 0; k < total_sequence_length; ++k) { - sum += (softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k].ToFloat() * - v_matrix[input_2_base_offset + k * head_size + j].ToFloat()); - } - - output[output_base_offset + i * head_size + j] = MLFloat16(sum); - } - } - } - } +// Currently we only support CUDA for DecoderMaskedSelfAttention +#ifdef USE_CUDA - return output; -} -TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { +template +static void TestDecoderMaskedSelfAttention() { // The kernel is only supported on CC 5.3 or higher GPUs if (NeedSkipIfCudaArchLowerThan(530)) { return; @@ -661,19 +525,19 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { }; constexpr int sequence_length = 1; - constexpr int number_of_heads = 12; + constexpr int num_heads = 12; for (MyTestCase test_case : test_cases) { int batch_size = test_case.batch_size; int past_sequence_length = test_case.past_sequence_length; int hidden_size = test_case.hidden_size; - int head_size = (hidden_size / number_of_heads); + int head_size = (hidden_size / num_heads); int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("past_present_share_buffer", static_cast(1)); std::vector input_dims = {batch_size, sequence_length, hidden_size}; @@ -681,38 +545,38 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { std::vector bias_dims = {3 * hidden_size}; std::vector output_dims = {batch_size, sequence_length, hidden_size}; - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); // Mask tester.AddOptionalInputEdge(); // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + std::vector past_dims = {2, batch_size, num_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * num_heads * max_sequence_length * head_size; - auto kv_cache = CreateRandom(past_present_size); + auto kv_cache = CreateRandom(past_present_size); - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + num_heads, past_sequence_length, head_size, max_sequence_length); // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); + int chunk_size = 16 / sizeof(T); int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, 
num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + auto transposed = Transpose(kv_cache.data(), batch_size, num_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, num_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + tester.AddInput("past", past_dims, reordered_kv_cache); // Rel - tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); // Past sequence length std::vector arr_past_sequence_len(1, past_sequence_length); @@ -722,41 +586,44 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); auto* qkv_matrix = qkv.data(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); auto k_merged = pair.first; auto k_transpose = pair.second; - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, num_heads, + total_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, num_heads, + sequence_length, total_sequence_length); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + num_heads, past_sequence_length, max_sequence_length, head_size); // Validate our test logic // We want to validate if our merged "unordered" K is the same as // the merged "ordered" K so that the QKT we do in our test code // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, num_heads, total_sequence_length, + max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + num_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, num_heads, sequence_length, total_sequence_length, + max_sequence_length, head_size); // Output(s) 
- tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("output", input_dims, output); + tester.AddOutput("present", past_dims, present); - tester.SetOutputTolerance(0.001f, 0.001f); + if (std::is_same::value) { + tester.SetOutputTolerance(0.005f); + } else { + tester.SetOutputTolerance(0.001f, 0.001f); + } // Run - Regular kernel execution path { @@ -778,150 +645,292 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { } } -TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { - // The kernel is only supported on CC 5.3 or higher GPUs - if (NeedSkipIfCudaArchLowerThan(530)) { - return; - } - - // Buckets for test data: - // batch_size: 1, >=2 - // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) - // head_size: 32, 64, 128 - struct MyTestCase { - int batch_size; - int past_sequence_length; - int hidden_size; - } test_cases[] = { - {1, 0, 768}, - {1, 1, 768}, - {3, 30, 384}, - {8, 31, 1536}, - {4, 256, 384}, - {3, 1024, 768}, - {2, 2046, 1536}, - {1, 2047, 384}, - {2, 3000, 768}, - }; - - constexpr int sequence_length = 1; - constexpr int number_of_heads = 12; - - for (MyTestCase test_case : test_cases) { - int batch_size = test_case.batch_size; - int past_sequence_length = test_case.past_sequence_length; - int hidden_size = test_case.hidden_size; +#endif // USE_CUDA - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); +template +static std::vector CalculateOutputQK(const std::vector& q, const std::vector& k, + const std::vector& mask_index, const std::vector& attention_bias, + int batch_size, int num_heads, + int sequence_length, int max_sequence_length, int head_size) { + // q (B, 1, NH), k (B, N, L(M), H) -> qk (B, N, 1, L) + // mask_index (B, L), (optional) attention_bias (1, 1, 1, L) + float scale = 1 / sqrt(static_cast(head_size)); + std::vector output_qk; + output_qk.resize(static_cast(batch_size) * num_heads * sequence_length, static_cast(0.f)); + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int s = 0; s < sequence_length; ++s) { + float mask_value = (mask_index[b * sequence_length + s] == 0) ? -10000.f : 0.f; + float bias_value = (attention_bias.empty()) ? 
0.f : ToFloat(attention_bias[s]); + float sum = 0; + for (int h = 0; h < head_size; ++h) { + sum += ToFloat(q[b * num_heads * head_size + n * head_size + h]) * + ToFloat(k[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h]); + } - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; + output_qk[b * num_heads * sequence_length + n * sequence_length + s] = + static_cast(scale * sum + mask_value + bias_value); + } + } + } - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + return output_qk; +} - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); +template +static std::vector CalculateOutput(const std::vector& softmax, const std::vector& v, int batch_size, + int num_heads, int sequence_length, int max_sequence_length, int head_size) { + // softmax (B, N, 1, L) v (B, N, L(M), H) -> output (B, N, 1, H) + std::vector output; + output.resize(static_cast(batch_size) * num_heads * head_size, static_cast(0.f)); + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int h = 0; h < head_size; ++h) { + float sum = 0; + for (int s = 0; s < sequence_length; ++s) { + sum += ToFloat(softmax[b * num_heads * sequence_length + n * sequence_length + s]) * + ToFloat(v[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h]); + } - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + output[b * num_heads * head_size + n * head_size + h] = static_cast(sum); + } + } + } - // Mask - tester.AddOptionalInputEdge(); + return output; +} - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; +template +static std::vector MergePast(const std::vector& past, const std::vector& current, int batch_size, + int num_heads, int past_seq_len, int max_seq_len, int head_size) { + // past (B, N, S(M), H), current (B, 1, NH) -> merged (B, N, S+1(M), H) + std::vector merged = past; + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int h = 0; h < head_size; ++h) { + merged[b * num_heads * max_seq_len * head_size + n * max_seq_len * head_size + past_seq_len * head_size + h] = + current[b * num_heads * head_size + n * head_size + h]; + } + } + } - auto kv_cache = CreateRandom(past_present_size); + return merged; +} - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); +template +static std::vector ReorderKVByCacheIndirection(const std::vector& key_or_value, + const int32_t* cache_indirection, + int batch_size, int beam_width, int max_sequence_length, + int num_heads, int head_size, int past_sequence_length) { + std::vector reordered = key_or_value; - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), 
reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + for (int b = 0; b < batch_size; ++b) { + int beam_batch_index = b / beam_width; + const int* beam_indices = cache_indirection + b * max_sequence_length; + for (int n = 0; n < num_heads; ++n) { + for (int s = 0; s < past_sequence_length; ++s) { + int beam_offset = beam_indices[s] * num_heads * max_sequence_length * head_size; + int beam_batch_offset = (beam_batch_index * beam_width * num_heads + n) * max_sequence_length * head_size; + for (int h = 0; h < head_size; ++h) { + reordered[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h] = + key_or_value[beam_offset + beam_batch_offset + s * head_size + h]; + } + } + } + } - tester.AddInput("past", past_dims, reordered_kv_cache); + return reordered; +} - // Rel - tester.AddOptionalInputEdge(); +template +static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool use_cuda = true) { + int batch_size = 8; + int past_sequence_length = 2; + int kv_sequence_length = 16; + int head_size = 32; + int num_heads = 12; + int beam_width = 4; + int hidden_size = head_size * num_heads; + + OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain); + FixedPatternValueGenerator generator{}; + RandomValueGenerator random{123}; + + // Attributes + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(!is_cross_attn)); + // Output scaled Q * K^T by default for cross-attention + tester.AddAttribute("output_qk", static_cast(is_cross_attn)); + + // Inputs and outputs + auto query = CreateRandom(batch_size * 1 * hidden_size); + tester.AddInput("query", {batch_size, 1, hidden_size}, query); + + if (is_cross_attn) { + auto key = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + std::vector reordered_key; + if (use_cuda) { + reordered_key = ReorderKVCache(key, batch_size, num_heads, + kv_sequence_length, head_size, kv_sequence_length, false); + } + auto value = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + tester.AddInput("key", {batch_size, num_heads, kv_sequence_length, head_size}, (use_cuda ? 
reordered_key : key)); + tester.AddInput("value", {batch_size, num_heads, kv_sequence_length, head_size}, + CreateRandom(batch_size * num_heads * kv_sequence_length * head_size)); + + const std::vector mask_index_dims = {batch_size, kv_sequence_length}; + auto mask_index = generator.Discrete(mask_index_dims, AsSpan({0, 1})); + tester.AddInput("mask_index", {batch_size, kv_sequence_length}, mask_index); + + // Calculate Softmax(Q * K^T + (Optional) mask) * V + std::vector empty_attention_bias; + auto output_qk = CalculateOutputQK(query, key, mask_index, empty_attention_bias, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + std::vector output_qk_float(output_qk.size()); + for (size_t i = 0; i < output_qk.size(); ++i) { + output_qk_float[i] = static_cast(output_qk[i]); + } + auto softmax = Softmax_QK_Transpose(output_qk.data(), batch_size, num_heads, 1, kv_sequence_length); + auto output = CalculateOutput(softmax, value, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + + tester.AddOutput("output", {batch_size, 1, hidden_size}, output); + tester.AddOptionalOutputEdge(); // optional present_key + tester.AddOptionalOutputEdge(); // optional present_value + tester.AddOutput("qk", {batch_size, num_heads, 1, kv_sequence_length}, output_qk_float); + } else { + int max_sequence_length = past_sequence_length + 10; + int total_sequence_length = past_sequence_length + 1; + + auto key = CreateRandom(batch_size * hidden_size); + auto value = CreateRandom(batch_size * hidden_size); + tester.AddInput("key", {batch_size, 1, hidden_size}, key); + tester.AddInput("value", {batch_size, 1, hidden_size}, value); + + const std::vector mask_index_dims = {batch_size, total_sequence_length}; + auto mask_index = generator.Discrete(mask_index_dims, AsSpan({0, 1})); + tester.AddInput("mask_index", {batch_size, total_sequence_length}, mask_index); + std::vector attention_bias_dims = {1, 1, 1, total_sequence_length}; + auto attention_bias_float = random.Gaussian(attention_bias_dims, 0.0f, 0.3f); + std::vector attention_bias(attention_bias_float.size()); + for (size_t i = 0; i < attention_bias.size(); ++i) { + attention_bias[i] = static_cast(attention_bias_float[i]); + } + tester.AddInput("attention_bias", {1, 1, 1, total_sequence_length}, attention_bias); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + auto past_key = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); + auto past_value = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + std::vector reordered_past_key; // For CUDA, we need to reorder past key + if (use_cuda) { + reordered_past_key = ReorderKVCache(past_key, batch_size, num_heads, + past_sequence_length, head_size, max_sequence_length, false); + } - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + tester.AddInput("past_key", {batch_size, num_heads, max_sequence_length, head_size}, + (use_cuda ? 
reordered_past_key : past_key)); + tester.AddInput("past_value", {batch_size, num_heads, max_sequence_length, head_size}, past_value); + + // merge past key and value with current key and value + auto merged_key = MergePast(past_key, key, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); + std::vector merged_reordered_key; + if (use_cuda) { + merged_reordered_key = MergeReorderedKVCacheWithK(reordered_past_key, key.data(), batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size, false); + } + auto merged_value = MergePast(past_value, value, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); + + tester.AddInput("past_sequence_length", {1}, {past_sequence_length}); + + std::vector mod_merged_key, mod_merged_value; + if (beam_width > 1) { + tester.AddInput("beam_width", {1}, {beam_width}); + + const std::vector cache_indir_dims = {batch_size, beam_width, max_sequence_length}; + auto value_candidates = ValueRange(beam_width); + auto cache_indir = generator.Discrete(cache_indir_dims, value_candidates); + tester.AddInput("cache_indirection", cache_indir_dims, cache_indir); + + // Modify merged_key and merged_value according to cache_indirection + mod_merged_key = ReorderKVByCacheIndirection(merged_key, cache_indir.data(), + batch_size, beam_width, max_sequence_length, + num_heads, head_size, past_sequence_length); + mod_merged_value = ReorderKVByCacheIndirection(merged_value, cache_indir.data(), + batch_size, beam_width, max_sequence_length, + num_heads, head_size, past_sequence_length); + } - auto k_merged = pair.first; - auto k_transpose = pair.second; + // Calculate Softmax(Q * K^T + (Optional) mask) * V + auto output_qk = CalculateOutputQK(query, (beam_width > 1 ? mod_merged_key : merged_key), + mask_index, attention_bias, + batch_size, num_heads, total_sequence_length, max_sequence_length, head_size); + auto softmax = Softmax_QK_Transpose(output_qk.data(), batch_size, num_heads, 1, total_sequence_length); + auto output = CalculateOutput(softmax, (beam_width > 1 ? mod_merged_value : merged_value), + batch_size, num_heads, total_sequence_length, max_sequence_length, head_size); + + tester.AddOutput("output", {batch_size, 1, hidden_size}, output); + tester.AddOutput("present_key", {batch_size, num_heads, max_sequence_length, head_size}, + (use_cuda ? 
merged_reordered_key : merged_key)); + tester.AddOutput("present_value", {batch_size, num_heads, max_sequence_length, head_size}, merged_value); + } - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + if (std::is_same::value) { + tester.SetOutputTolerance(0.02f); + } else { + tester.SetOutputTolerance(0.0001f, 0.0001f); + } - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + { + std::vector> execution_providers; + if (use_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } else { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); +#ifdef USE_CUDA - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); +TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { + TestDecoderMaskedSelfAttention(); +} - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); +TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { + TestDecoderMaskedSelfAttention(); +} - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(); +} - // Output(s) - tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp16) { + TestDecoderMaskedMultiHeadAttention(); +} - tester.SetOutputTolerance(0.005f); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false); +} - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp16) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false); +} - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; +#endif - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } +TEST(DecoderMaskedMultiHeadAttentionTest, cpu_cross_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ true, /* use_cuda = */ false); } -#endif +TEST(DecoderMaskedMultiHeadAttentionTest, cpu_self_attn_fp32) 
{ + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false, /* use_cuda = */ false); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc index ea6f93a273055..4754f3a520694 100644 --- a/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc @@ -11,12 +11,12 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { +TEST(DynamicTimeWarping, simple) { #ifdef USE_CUDA - -TEST(DynamicTimeWarp, simple) { if (NeedSkipIfCudaArchLowerThan(530)) { return; } +#endif std::vector X = { 3.0f, @@ -113,11 +113,12 @@ TEST(DynamicTimeWarp, simple) { tester.AddOutput("output", {2, 12}, Y); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#endif + execution_providers.push_back(DefaultCpuExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -#endif - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 5cf749dc4c97c..a7d751f4472fc 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -41,7 +41,7 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& bias_dims, const std::vector& output_dims, @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector& input_data, const std:: execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -142,7 +144,7 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); #endif RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); @@ -245,8 +247,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); } -// CUDA and ROCm only for Float16 and BFloat16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) +// CUDA, ROCm and WebGPU only for Float16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(FastGeluTest, FastGeluWithBiasFloat16_2) { int batch_size = 1; int sequence_length = 2; @@ -381,7 +383,10 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) { RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true); } +#endif +// CUDA and ROCm only for BFloat16 type. 
+#if defined(USE_CUDA) || defined(USE_ROCM) TEST(FastGeluTest, FastGeluWithBias_BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 655c4951f262d..52e67bf0616d1 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -120,7 +120,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { @@ -134,7 +134,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { @@ -151,6 +151,20 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); } +TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput_Initializers) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), true); + test.AddOutput("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f})); + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider}); +} + TEST(LayerNormTest, LayerNorm_Scale_Bias) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -178,7 +192,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, - kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider}); + kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { @@ -193,34 +207,82 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, 
kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { + auto run_test = [](bool is_initializer) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), is_initializer); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f}), is_initializer); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL, OpenVINO and NNAPI don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); + }; + run_test(false); + run_test(true); +} + +template +class LayerNormTest : public ::testing::Test { +}; + +using LayerNormTestTypes = ::testing::Types; +TYPED_TEST_SUITE(LayerNormTest, LayerNormTestTypes); + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializers) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); std::vector dims{1, 3, 2}; test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); - test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); - test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), true); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f}), true); test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider}); } // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. -TEST(LayerNormTest, LayerNorm17_float) { - OpTester test("LayerNormalization", 17); - test.AddAttribute("epsilon", 1e-05f); +TYPED_TEST(LayerNormTest, LayerNorm17_opset) { + auto run_test = [](bool is_initializer) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, GetTypedArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("gamma", {3}, GetTypedArray({1.0f, 1.0f, 1.0f}), is_initializer); + test.AddOutput("output", dims, GetTypedArray({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f})); + if (std::is_same::value) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider}, + nullptr, &execution_providers); + } else { + test.Run(); + } + }; + // Execution provider entry invalid. + // when other EPs support layer-norm fp16, this test should be updated to include them. 
+ if (std::is_same::value) { +#if !defined(COREML_ENABLE_MLPROGRAM) + return; +#endif + } - std::vector dims{1, 2, 3}; - test.AddInput("x", dims, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - test.AddInput("gamma", {3}, {1.0f, 1.0f, 1.0f}); - test.AddOutput("output", dims, {-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}); - test.Run(); + run_test(false); + run_test(true); } TEST(LayerNormTest, LayerNorm17_double) { diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 438a1100ca95c..46082e1b0cd31 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -6,7 +6,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) constexpr auto k_epsilon_default = 1e-5f; constexpr auto k_random_data_min = -10.0f; constexpr auto k_random_data_max = 10.0f; @@ -65,8 +65,8 @@ static void TestLayerNorm(const std::vector& x_dims, std::vector Y_data = FillZeros(n_x_m_dims); test.AddOutput("output", n_x_m_dims, Y_data); -#ifndef USE_DML - // DML doesn't support more than one output for these ops yet +#if !defined(USE_DML) && !defined(USE_WEBGPU) + // DML and WebGPU don't support more than one output for these ops yet const std::vector& stats_dims = keep_dims ? n_and_ones_dims : n_dims; std::vector mean_data = FillZeros(stats_dims); std::vector var_data = FillZeros(stats_dims); @@ -84,6 +84,8 @@ static void TestLayerNorm(const std::vector& x_dims, test.CompareWithCPU(kRocmExecutionProvider); #elif USE_DML test.CompareWithCPU(kDmlExecutionProvider); +#elif USE_WEBGPU + test.CompareWithCPU(kWebGpuExecutionProvider); #endif } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 8138829b057f2..eebe9197573c6 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -82,6 +82,7 @@ struct TestOptions { bool has_bias{false}; std::optional output_abs_error{}; + std::optional output_rel_error{}; }; std::ostream& operator<<(std::ostream& os, const TestOptions& opts) { @@ -253,6 +254,10 @@ void RunTest(const TestOptions& opts, test.SetOutputAbsErr("Y", *opts.output_abs_error); } + if (opts.output_rel_error.has_value()) { + test.SetOutputRelErr("Y", *opts.output_rel_error); + } + if (!explicit_eps.empty()) { test.ConfigEps(std::move(explicit_eps)); } @@ -271,10 +276,10 @@ void TestMatMulNBitsTyped() { if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; - } else { - if constexpr (std::is_same::value) { - base_opts.output_abs_error = 0.01f; - } + base_opts.output_rel_error = 0.02f; + } else if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.055f; + base_opts.output_rel_error = 0.02f; } { @@ -288,7 +293,7 @@ void TestMatMulNBitsTyped() { RunTest(opts); } -#if !defined(USE_DML) +#if !defined(USE_DML) && !defined(USE_WEBGPU) { TestOptions opts = base_opts; opts.has_g_idx = true; @@ -319,7 +324,7 @@ void TestMatMulNBitsTyped() { opts.has_zero_point = true, opts.zp_is_4bit = false; RunTest(opts); } -#endif // !defined(USE_DML) +#endif // !defined(USE_DML) && !defined(USE_WEBGPU) } TEST(MatMulNBits, Float32_Accuracy0) { @@ -387,48 +392,48 @@ TEST(MatMulNBits, Float32_Accuracy4) { TestMatMulNBitsTyped(); } -#ifdef MLAS_TARGET_AMD64_IX86 +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64) #if 
!defined(USE_DML) // Actual and expected difference is over 0.01 with DmlExecutionProvider. // Skip the tests instead of raising the tolerance to make is pass. +TEST(MatMulNBits, Float16_Accuracy2) { + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); +} + TEST(MatMulNBits, Float16_Accuracy0) { TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float16_Accuracy1) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); -} - TEST(MatMulNBits, Float16_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -458,7 +463,7 @@ TEST(MatMulNBits, Float16_Accuracy4) { #endif #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) namespace { // Legacy test function. @@ -493,6 +498,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); #endif +#ifdef USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); +#endif RunTest(opts, std::move(execution_providers)); } else { @@ -537,6 +545,9 @@ TEST(MatMulNBits, Float16Large) { // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. float abs_error = 0.3f; +#elif USE_WEBGPU + // See Intel A770 to pass these tests with an absolute error of 0.08. + float abs_error = 0.08f; #else float abs_error = 0.05f; #endif diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 1d167b5dffdb5..6b6799d73fb56 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -49,6 +49,7 @@ static void RunMultiHeadAttentionTest( bool use_float16 = false, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, // not supported in rocm right now. bool disable_dml = false) { kv_sequence_length = (kv_sequence_length == 0 ? 
sequence_length : kv_sequence_length); @@ -59,6 +60,7 @@ static void RunMultiHeadAttentionTest( bool enable_rocm = (nullptr != DefaultRocmExecutionProvider(/*test_tunable_op=*/true).get()) && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()) && !disable_webgpu; if (enable_rocm && !use_float16) { LOGS_DEFAULT(WARNING) << "ROCm MHA only have kernel for half datatype implemented, skip float datatype tests"; @@ -70,7 +72,7 @@ static void RunMultiHeadAttentionTest( enable_rocm = false; } - if (enable_cpu || enable_cuda || enable_rocm || enable_dml) { + if (enable_cpu || enable_cuda || enable_rocm || enable_dml || enable_webgpu) { OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); @@ -266,6 +268,12 @@ static void RunMultiHeadAttentionTest( execution_providers.push_back(DefaultDmlExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + + if (enable_webgpu) { + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } @@ -295,6 +303,7 @@ static void RunMultiHeadAttentionKernel( bool is_static_kv = true, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, bool disable_dml = false) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { @@ -309,7 +318,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -325,7 +335,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -341,7 +352,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); 
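  // Editorial sketch (not part of the original patch): every kernel variant in this file forwards the
  // same disable_cpu / disable_cuda / disable_webgpu / disable_rocm / disable_dml flags, and each
  // enabled provider is exercised with its own provider list. Assuming an OpTester named `tester`
  // and an `enable_webgpu` flag computed as above, the per-EP run pattern this patch adds is roughly:
  //
  //   if (enable_webgpu) {
  //     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  //     execution_providers.push_back(DefaultWebGpuExecutionProvider());
  //     tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  //   }
  //
  // i.e. the same test data is replayed once per available EP rather than relying on a single
  // default provider selection.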
return; } @@ -358,7 +370,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } #endif @@ -376,7 +389,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { @@ -392,11 +406,30 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } } -static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { +enum RunMultiHeadAttentionTestToggles : uint32_t { + DISABLE_NONE = 0, + DISABLE_CPU = 1 << 0, + DISABLE_CUDA = 1 << 1, + DISABLE_WEBGPU = 1 << 2, +}; +inline RunMultiHeadAttentionTestToggles operator|(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) | static_cast(b)); +} +inline RunMultiHeadAttentionTestToggles operator&(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) & static_cast(b)); +} + +static void RunMultiHeadAttentionTests(AttentionTestData& data, + RunMultiHeadAttentionTestToggles toggles = DISABLE_NONE) { + bool disable_cpu = toggles & DISABLE_CPU; + bool disable_cuda = toggles & DISABLE_CUDA; + bool disable_webgpu = toggles & DISABLE_WEBGPU; + if (data.fp32_output_data.size() > 0) { constexpr bool use_float16 = false; @@ -407,7 +440,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -420,7 +453,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, 
data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } #endif @@ -431,7 +464,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } if (data.fp16_output_data.size() > 0) { @@ -443,7 +476,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; @@ -453,7 +486,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -464,7 +497,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #endif @@ -475,7 +508,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, 
disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_Default; @@ -484,7 +517,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } @@ -503,40 +536,40 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } // This tests qk_head_size != v_head_size @@ -561,7 +594,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, false, true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // TODO (pavignol): Fix this regression @@ -571,7 +604,7 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetCrossAttentionDataWithPast(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } #endif @@ -579,27 +612,27 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only 
support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU); } TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; GetAttentionDataCutlassAttnBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { // Whisper decoder cross attention without mask and different sequence lengths for Q and K/V AttentionTestData data; GetCrossAttentionData_DiffSequenceLengths(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { @@ -609,10 +642,10 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) RunMultiHeadAttentionTests(data); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // This test is disabled since it is not used in Whisper anymore, and it fails in ROCm. diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 89552da58b938..0e964cf64fbbd 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -67,6 +67,7 @@ static void RunTest( : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); if (enable_cuda && !disable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); @@ -74,9 +75,12 @@ static void RunTest( if (enable_dml && !disable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } - if (tensor_type == TensorType::kFloat && !disable_cpu) { + if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } if (execution_providers.size() == 0) { // Return early if CI pipeline does not support EP (e.g. 
CUDA EP for CPU CI pipeline) return; @@ -135,28 +139,8 @@ static void RunTests(const std::vector& input_data, int max_sequence_length = 0, int64_t interleaved = 0, int64_t is_packed_batching = 0, - bool use_float16 = true, - bool disable_dml = false) { - // FP32 test for CPU - RunTest(input_data, - position_ids, - cos_cache, - sin_cache, - output_data, - batch_size, - sequence_length, - head_size, - rotary_embedding_dim, - num_heads, - max_sequence_length, - interleaved, - is_packed_batching, - TensorType::kFloat, - false, /* disable_cpu */ - true, /* disable_cuda */ - true /* disable_dml */); - - // FP32 test for CUDA and DML + bool use_float16 = true) { + // FP32 test for CPU, CUDA and DML RunTest(input_data, position_ids, cos_cache, @@ -173,9 +157,9 @@ static void RunTests(const std::vector& input_data, TensorType::kFloat, false, /* disable_cpu */ false, /* disable_cuda */ - disable_dml || false /* disable_dml */); + false /* disable_dml */); - // FP16 test for CUDA and DML + // FP16 test for CPU, CUDA and DML if (use_float16) { RunTest(input_data, position_ids, @@ -191,26 +175,9 @@ static void RunTests(const std::vector& input_data, interleaved, is_packed_batching, TensorType::kFloat16, - true, /* disable_cpu */ + false, /* disable_cpu */ false, /* disable_cuda*/ - disable_dml || false /* disable_dml */); - - // RunTest(input_data, - // position_ids, - // cos_cache, - // sin_cache, - // output_data, - // batch_size, - // sequence_length, - // head_size, - // rotary_embedding_dim, - // num_heads, - // max_sequence_length, - // interleaved, - // TensorType::kBFloat16, - // true, /* disable_cpu */ - // false, /* disable_cuda*/ - // false /* disable_dml */); + false /* disable_dml */); } } @@ -743,9 +710,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { num_heads, max_sequence_length, interleaved, - 0, // is_packed_batching - true, /*use_fp16*/ - true /*disable_dml*/); + 0, // is_packed_batching + true /*use_fp16*/); } TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) { @@ -785,9 +751,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_B num_heads, max_sequence_length, interleaved, - 1, // is_packed_batching - true, /*use_fp16*/ - true /*disable_dml*/); + 1, // is_packed_batching + true /*use_fp16*/); } } // namespace test diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index edf9064bb43c9..4e8d1b9f016f0 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -62,6 +62,8 @@ static void RunOneTest( auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); auto cpu_ep = DefaultCpuExecutionProvider(); + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + std::vector> execution_providers; if (!use_float16) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); @@ -95,10 +97,14 @@ static void RunOneTest( if (cpu_ep != nullptr) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || - rocm_ep != nullptr) { + rocm_ep != nullptr || + webgpu_ep != nullptr) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); 
test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("skip", skip_dims, ToFloat16(skip_data)); @@ -132,7 +138,9 @@ static void RunOneTest( ToFloat16(sum_output_data)); } - if (dml_ep != nullptr) { + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } else if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); } else if (rocm_ep != nullptr) { execution_providers.push_back(DefaultRocmExecutionProvider()); @@ -186,6 +194,32 @@ static void RunTest( } } +TEST(SkipLayerNormTest, SkipLayerNormPrePack) { + OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 2; + std::vector input_skip_output_dims = {batch_size, sequence_length, hidden_size}; + std::vector gamma_beta_bias_dims = {hidden_size}; + test.AddInput("x", input_skip_output_dims, ToFloat16({1.f, 1.f, 1.f, 1.f})); + test.AddInput("skip", input_skip_output_dims, ToFloat16({1.f, 1.f, 1.f, 1.f})); + test.AddInput("gamma", gamma_beta_bias_dims, ToFloat16({1.f, 1.f}), true); + test.AddInput("beta", gamma_beta_bias_dims, ToFloat16({1.f, 1.f}), true); + test.AddOutput("output", input_skip_output_dims, ToFloat16({ + 1.f, + 1.f, + 1.f, + 1.f, + })); + + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider}); +} + TEST(SkipLayerNormTest, SkipLayerNormNullInput) { int batch_size = 1; int sequence_length = 0; diff --git a/onnxruntime/test/contrib_ops/tensor_op_test.cc b/onnxruntime/test/contrib_ops/tensor_op_test.cc index 81c8641f450f6..bc2ff5f4f724d 100644 --- a/onnxruntime/test/contrib_ops/tensor_op_test.cc +++ b/onnxruntime/test/contrib_ops/tensor_op_test.cc @@ -205,12 +205,12 @@ TEST(MVNContribOpTest, MeanVarianceNormalizationCPUTest_Version1_TO_8) { MeanVarianceNormalizationPerChannel(false, true); } -#ifdef USE_CUDA - TEST(UnfoldTensorOpTest, LastDim) { +#ifdef USE_CUDA if (NeedSkipIfCudaArchLowerThan(530)) { return; } +#endif std::vector X = { 1.0f, 2.0f, 3.0f, 4.0f, @@ -229,7 +229,10 @@ TEST(UnfoldTensorOpTest, LastDim) { tester.AddOutput("output", {3, 2, 3}, output); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#endif + execution_providers.push_back(DefaultCpuExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -238,13 +241,13 @@ TEST(UnfoldTensorOpTest, NormalDim) { return; } - std::vector X = { + std::vector X = { 1, 2, 3, 4, 2, 2, 3, 4, 3, 2, 3, 4, 4, 6, 7, 8, 5, 6, 7, 8, 6, 6, 7, 8, 6, 7, 8, 9, 7, 7, 8, 9, 8, 7, 8, 9, 9, 7, 8, 9, 10, 7, 8, 9, 11, 7, 8, 9}; - std::vector output = { + std::vector output = { 1, 2, 3, 2, 2, 2, 3, 3, 3, @@ -269,15 +272,16 @@ TEST(UnfoldTensorOpTest, NormalDim) { tester.AddAttribute("dim", 1LL); tester.AddAttribute("size", 3LL); tester.AddAttribute("step", 2LL); - tester.AddInput("input", {2, 6, 4}, X); - tester.AddOutput("output", {2, 2, 4, 3}, output); + tester.AddInput("input", {2, 6, 4}, X); + tester.AddOutput("output", {2, 2, 4, 3}, output); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#endif + execution_providers.push_back(DefaultCpuExecutionProvider()); 
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -#endif - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 0105e90b5a24a..eaebac177ca91 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -251,6 +251,7 @@ class PlannerTest : public ::testing::Test { void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg, std::unordered_map>& kernel_create_info_map) { + const auto& logger = DefaultLoggingManager().DefaultLogger(); const IExecutionProvider* ep = execution_providers_.Get(*p_node); ASSERT_NE(ep, nullptr); auto info = std::make_unique( @@ -260,7 +261,7 @@ class PlannerTest : public ::testing::Test { op_kernel_infos_.push_back(std::move(info)); const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{}; if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider, - kernel_type_str_resolver)) { + kernel_type_str_resolver, logger)) { ASSERT_STATUS_OK(reg->Register( KernelCreateInfo(std::make_unique(kernel_def), [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { @@ -270,7 +271,7 @@ class PlannerTest : public ::testing::Test { } const KernelCreateInfo* kci; - ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, &kci)); + ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, logger, &kci)); kernel_create_info_map.insert({p_node->Index(), gsl::not_null(kci)}); } @@ -282,7 +283,8 @@ class PlannerTest : public ::testing::Test { } } - void CreatePlan(const std::vector& outer_scope_node_args = {}, bool invoke_createPlan_explicityly = true) { + void CreatePlan(const std::vector& outer_scope_node_args = {}, + bool invoke_createPlan_explicityly = true) { state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); EXPECT_EQ(graph_.Resolve(), Status::OK()); diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index 8961058628490..fa6c4966d6953 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -3,6 +3,7 @@ #include #include "core/framework/allocator.h" +#include "core/framework/allocator_utils.h" #include "test_utils.h" #include "gtest/gtest.h" @@ -15,12 +16,10 @@ TEST(AllocatorTest, CPUAllocatorTest) { ASSERT_STREQ(cpu_arena->Info().name, CPU); EXPECT_EQ(cpu_arena->Info().id, 0); - // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) - EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtArenaAllocator); -#else - EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtDeviceAllocator); -#endif + const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() + ? 
OrtAllocatorType::OrtArenaAllocator + : OrtAllocatorType::OrtDeviceAllocator; + EXPECT_EQ(cpu_arena->Info().alloc_type, expected_allocator_type); size_t size = 1024; auto bytes = cpu_arena->Alloc(size); diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index fa3545ef27d72..180a75a64c10e 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -580,13 +580,7 @@ TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) { // myfun is not removed because it was claimed by InternalTestingEP model_proto = session_object.GetModel().ToProto(); -#ifdef USE_TVM - // TVM EP takes the whole graph and optimizes it within its own framework. - // It does not retain the original graph. - ASSERT_EQ(0, model_proto.functions_size()); -#else ASSERT_EQ(1, model_proto.functions_size()); -#endif } TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) { diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 8b230db351edc..740c566794f15 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -693,6 +693,9 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #endif #ifdef USE_ROCM ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); +#endif +#ifdef USE_WEBGPU + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultWebGpuExecutionProvider())); #endif ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); ASSERT_STATUS_OK(session_object.Initialize()); @@ -719,7 +722,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { ASSERT_TRUE(lines[size - 1].find("]") != string::npos); std::vector tags = {"pid", "dur", "ts", "ph", "X", "name", "args"}; - bool has_api_info = false; + [[maybe_unused]] bool has_api_info = false; for (size_t i = 1; i < size - 1; ++i) { for (auto& s : tags) { ASSERT_TRUE(lines[i].find(s) != string::npos); @@ -730,14 +733,16 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #ifdef USE_ROCM has_api_info = has_api_info || lines[i].find("Api") != string::npos && lines[i].find("hipLaunch") != string::npos; +#endif +#ifdef USE_WEBGPU + has_api_info = has_api_info || lines[i].find("Api") != string::npos; #endif } } -#if defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING) +// Note that the apple device is a paravirtual device which may not support webgpu timestamp query. So skip the check on it. 
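// Editorial sketch (not part of the original patch): the guarded assertion below only compiles in
// when the build is actually expected to emit "Api" profiling events. Expressed as ordinary C++
// (names here are illustrative, the real gate is the preprocessor condition that follows):
//
//   constexpr bool expect_api_events =
//       rocm_build_with_profiling        // ROCm + ENABLE_ROCM_PROFILING emits hipLaunch "Api" events
//       || (webgpu_build && !apple);     // WebGPU timestamp queries; Apple CI GPUs may not support them
//   if constexpr (expect_api_events) {
//     ASSERT_TRUE(has_api_info);
//   }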
+#if (defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING)) || (defined(USE_WEBGPU) && !defined(__APPLE__)) ASSERT_TRUE(has_api_info); -#else - ASSERT_TRUE(has_api_info || true); #endif } @@ -765,7 +770,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { while (std::getline(profile, line)) { if (count == 0) { ASSERT_TRUE(line.find("[") != string::npos); - } else if (count <= 5) { + } else if (count <= 3) { for (auto& s : tags) { ASSERT_TRUE(line.find(s) != string::npos); } @@ -774,7 +779,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { } if (count == 1) { - ASSERT_TRUE(line.find("mul_1_fence_before") != string::npos); + ASSERT_TRUE(line.find("mul_1_kernel_time") != string::npos); } count++; } @@ -806,6 +811,47 @@ TEST(InferenceSessionTests, CheckRunProfilerStartTime) { ASSERT_TRUE(before_start_time <= profiling_start_time && profiling_start_time <= after_start_time); } +TEST(InferenceSessionTests, CheckRunProfilerWithOptionalValues) { + // Test whether the profiler can work on model with optional values + SessionOptions so; + + so.session_logid = "CheckRunProfiler"; + so.enable_profiling = true; + so.profile_file_prefix = ORT_TSTR("onnxprofile_profile_test"); + + InferenceSession session_object(so, GetEnvironment()); + ASSERT_STATUS_OK(session_object.Load(ORT_TSTR("testdata/relu_with_optional.onnx"))); + ASSERT_STATUS_OK(session_object.Initialize()); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + + // prepare inputs + std::vector dims_x = {1}; + std::vector values_x = {-4}; + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_x, values_x, &ml_value); + NameMLValMap feeds; + feeds.insert(std::make_pair("input", ml_value)); + + // prepare outputs + std::vector output_names; + output_names.push_back("output"); + std::vector fetches; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {1}; + std::vector expected_values_y = {0}; + + // Now run + common::Status st = session_object.Run(run_options, feeds, output_names, &fetches); + if (!st.IsOK()) { + std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; + } + ASSERT_TRUE(st.IsOK()); + VerifyOutputs(fetches.at(0).Get(), expected_dims_y, expected_values_y); +} + TEST(InferenceSessionTests, MultipleSessionsNoTimeout) { SessionOptions session_options; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index b94d24a1b180b..3e694020f796b 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -5,6 +5,7 @@ #include #include "asserts.h" +#include "core/framework/allocator_utils.h" #include "core/framework/execution_providers.h" #include "core/framework/graph_partitioner.h" #include "core/framework/kernel_registry.h" @@ -216,10 +217,12 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { // Test that we allocate memory for an initializer from non-arena memory even if we provide an arena-based allocator // if the relevant session option config flag is set -// For this test we need to enable the arena-based allocator which is not supported on x86 builds, so -// enable this test only on x64 builds -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { + // For this test we need to enable the arena-based 
allocator. + if (!DoesCpuAllocatorSupportArenaUsage()) { + GTEST_SKIP() << "CPU allocator does not support arena usage."; + } + AllocatorPtr cpu_allocator = std::make_shared(); // Part 1: Feature turned ON (i.e.) allocate from non-arena memory { @@ -348,8 +351,6 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { } } -#endif - INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStateTestP, testing::ValuesIn(param_list)); #ifndef ENABLE_TRAINING_CORE diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index 9202543b75a6f..fba099f9c55b3 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/tensor.h" +#include "core/framework/allocator_utils.h" #include "test_utils.h" #include "gmock/gmock.h" @@ -137,12 +138,10 @@ TEST(TensorTest, EmptyTensorTest) { ASSERT_STREQ(location.name, CPU); EXPECT_EQ(location.id, 0); - // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) - EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtArenaAllocator); -#else - EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtDeviceAllocator); -#endif + const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() + ? OrtAllocatorType::OrtArenaAllocator + : OrtAllocatorType::OrtDeviceAllocator; + EXPECT_EQ(location.alloc_type, expected_allocator_type); } TEST(TensorTest, StringTensorTest) { diff --git a/onnxruntime/test/logging_apis/test_logging_apis.cc b/onnxruntime/test/logging_apis/test_logging_apis.cc index d72c47493d800..b98e5c34b4e1d 100644 --- a/onnxruntime/test/logging_apis/test_logging_apis.cc +++ b/onnxruntime/test/logging_apis/test_logging_apis.cc @@ -359,12 +359,16 @@ TEST_F(MockCAPITestsFixture, CppLogMacroBypassCApiCall) { #undef TEST_MAIN #define TEST_MAIN main_no_link_ // there is a UI test app for iOS. -// IOS tests require this function to be defined. +// iOS tests require ortenv_setup() and ortenv_teardown() to be defined. // See onnxruntime/test/xctest/xcgtest.mm -void ortenv_setup() { +extern "C" void ortenv_setup() { // Do nothing. These logging tests do not require an env to be setup initially. } +extern "C" void ortenv_teardown() { + // Do nothing. 
+} + #endif // TARGET_OS_SIMULATOR || TARGET_OS_IOS #endif // defined(__APPLE__) diff --git a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp b/onnxruntime/test/mlas/bench/bench_cast.cpp similarity index 100% rename from onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp rename to onnxruntime/test/mlas/bench/bench_cast.cpp diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp similarity index 53% rename from onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp rename to onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp index 71db7d81075b5..64d229889214b 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "benchmark/benchmark.h" @@ -16,16 +17,16 @@ #include "core/util/thread_utils.h" #include "core/platform/env_var_utils.h" -template -void RunSQNBitGemmBenchmark(size_t BlkLen, - size_t M, size_t N, size_t K, - size_t Threads, - bool Symmetric, - bool HasBias, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - benchmark::State& state) { - if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); +template +void RunQNBitGemmBenchmark(size_t BlkLen, + size_t M, size_t N, size_t K, + size_t Threads, + bool Symmetric, + bool HasBias, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + benchmark::State& state) { + if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { + state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine."); return; } @@ -43,40 +44,40 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - const auto A = RandomVectorUniform(M * K, -1.0f, 1.0f); - const auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); + const auto A = RandomVectorUniform(M * K, AType(-1.0f), AType(1.0f)); + const auto B = RandomVectorUniform(K * N, AType(-1.0f), AType(1.0f)); - const auto Bias = HasBias ? RandomVectorUniform(N, -1.0f, 1.0f) : std::vector(); + const auto Bias = HasBias ? RandomVectorUniform(N, AType(-1.0f), AType(1.0f)) : std::vector(); - std::vector C(static_cast(M * N)); + std::vector C(static_cast(M * N)); std::vector QuantBData(QuantBDataSizeInBytes); - std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); bool has_zp_input = !Symmetric; - MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), + MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? 
nullptr : QuantBZeroPoint.data(), B.data(), static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), static_cast(N), tp.get()); std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), - tp.get()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), + tp.get()); } - MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; + MLAS_QNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) @@ -92,15 +93,15 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); } } -template -void SQNBITGEMM(benchmark::State& state) { +template +void QNBITGEMM(benchmark::State& state) { using onnxruntime::narrow; const auto BlkLen = narrow(state.range(0)); @@ -110,46 +111,50 @@ void SQNBITGEMM(benchmark::State& state) { const auto Threads = narrow(state.range(4)); const auto Symmetric = narrow(state.range(5)); const bool HasBias = narrow(state.range(6)); - const auto ComputeType = static_cast(state.range(7)); + const auto ComputeType = static_cast(state.range(7)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +template +static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); b->ArgsProduct({ - {128}, // BlkLen - {1}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType + {128}, // BlkLen + {1, 4096}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{false}, int64_t{true}}, // HasBias + std::is_same_v + ? 
std::vector{int64_t{HQNBIT_CompFp16}} + : std::vector{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); // This test gets benchmark arguments from environment variables. -template -void SQNBITGEMM_ENV(benchmark::State& state) { +template +void QNBITGEMM_ENV(benchmark::State& state) { using onnxruntime::ParseEnvironmentVariableWithDefault; - const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_BLKLEN", 32); - const auto M = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_M", 1); - const auto N = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_N", 4096); - const auto K = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_K", 4096); - const auto Threads = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_THREADS", 1); - const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); - const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); - const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", - static_cast(CompFp32)); + const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_BLKLEN", 32); + const auto M = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_M", 1); + const auto N = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_N", 4096); + const auto K = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_K", 4096); + const auto Threads = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_THREADS", 1); + const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_SYMMETRIC", true); + const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_HAS_BIAS", false); + const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_COMPUTE_TYPE", + static_cast(SQNBIT_CompFp32)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), - state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, + static_cast(ComputeType), + state); std::ostringstream s; s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen @@ -159,4 +164,4 @@ void SQNBITGEMM_ENV(benchmark::State& state) { state.SetLabel(s.str()); } -BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime(); +BENCHMARK(QNBITGEMM_ENV)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_util.h b/onnxruntime/test/mlas/bench/bench_util.h index f96dd5c673b3d..78789ef1cbc1a 100644 --- a/onnxruntime/test/mlas/bench/bench_util.h +++ b/onnxruntime/test/mlas/bench/bench_util.h @@ -8,8 +8,12 @@ #include #include +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas.h" + template -std::vector RandomVectorUniform( +typename std::enable_if_t, std::vector> +RandomVectorUniform( size_t N, ElementType min_value = std::numeric_limits::lowest(), ElementType max_value = std::numeric_limits::max()) { @@ -26,6 +30,25 @@ std::vector RandomVectorUniform( return r; } +template +typename std::enable_if_t, std::vector> +RandomVectorUniform( + size_t N, + ElementType min_value, + ElementType max_value) { + if (min_value.ToFloat() >= max_value.ToFloat()) { + return std::vector(N, min_value); + } + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(min_value.ToFloat(), max_value.ToFloat()); + + std::vector r(N); + for (size_t i = 0; i < N; i++) { + r[i] = 
ElementType(distribution(generator)); + } + return r; +} + std::vector RandomVectorUniform(std::vector shape, float min_value, float max_value); std::vector BenchArgsVector(benchmark::State& state, size_t& start, size_t count); diff --git a/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp new file mode 100644 index 0000000000000..b598c20e29280 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp @@ -0,0 +1,501 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hqnbitgemm_neon.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. + +--*/ + +#include +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/qnbitgemm.h" +#include "mlas_qnbit.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16CastTest : public MlasTestBase { + private: + MatrixGuardBuffer fp32Buffer_; + MatrixGuardBuffer fp16Buffer_; + + template + void TestFp16ToFp32() { + const auto* src = fp16Buffer_.GetFilledBuffer(count, [](unsigned short* start, size_t size) { + for (size_t i = 0; i < size; i++) { + start[i] = static_cast(i); + } + }); + auto* dest = fp32Buffer_.GetBuffer(count, true); + + MlasCastF16ToF32KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan + ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); + } + } + + template + void TestFp32ToFp16() { + const auto* src = fp32Buffer_.GetFilledBuffer(count, [](float* p, size_t size) { + for (size_t i = 0; i < size; i++) { + p[i] = static_cast(i) + 0.125f; + } + }); + auto* dest = fp16Buffer_.GetBuffer(count, true); + + MlasCastF32ToF16KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16Cast"; + } + + void ExecuteShort(void) override { + TestFp16ToFp32<(1 << 16)>(); + TestFp16ToFp32<1>(); + TestFp16ToFp32<4>(); + TestFp16ToFp32<7>(); + TestFp32ToFp16<(1 << 16)>(); + TestFp32ToFp16<3>(); + TestFp32ToFp16<4>(); + TestFp32ToFp16<6>(); + } +}; + +class MlasNeonFp16PrepackTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + MatrixGuardBuffer input_, ref_, packed_; + + template + MLAS_FORCEINLINE void Transpose8x8(const uint8_t* src, size_t n, size_t k, uint8_t* dst) { + for (size_t c = 0; c < 8; c++) { + for (size_t r = 0; r < 8; r++) { + size_t i = (n + c) * Ldb + r + k; + size_t j = n * Ldb + (r + k) * 8 + c; + dst[j] = src[i]; + } + } + } + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? 
(v >> 4) : (v & 0x0f); + } + + MLAS_FORCEINLINE + void PrepackSlice(const uint8_t* src, size_t j, uint8_t* dst) { + for (size_t i = 0; i < 8; i++) { + uint8_t v0 = GetInt4(src[j + (i >> 1)], i); + uint8_t v1 = GetInt4(src[j + ((8 + i) >> 1)], i + 8); + dst[j + i] = v0 | (v1 << 4); + } + } + + template + MLAS_FORCEINLINE void Prepack(const uint8_t* src, uint8_t* dst) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t k = 0; k < Ldb; k += 8) { + Transpose8x8(src, n, k, dst); + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < Ldb; k += 8) { + PrepackSlice(src, n * Ldb + k, dst); + } + } + } + + template + MLAS_FORCEINLINE void Check(const uint8_t* packed, const uint8_t* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; i += 2) { + for (size_t j = 0; j < 8; ++j) { + ASSERT_EQ(packed[n * Ldb + (i >> 1) * 8 + j], ref[n * Ldb + (i >> 1) * 8 + j]) + << " seed " << seed_ + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; i += 2) { + ASSERT_EQ(packed[n * Ldb + (i >> 1)], ref[n * Ldb + (i >> 1)]) + << " seed " << seed_ + << " n " << n << " i " << i; + } + } + } + + template + void TestPrepack() { + constexpr size_t Bits = 4; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t BufferSize = N * Ldb; + auto InitializeBuffer = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BufferSize, InitializeBuffer); + auto* packed = packed_.GetBuffer(BufferSize, true); + auto* ref = ref_.GetBuffer(BufferSize, true); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::HQNBIT_CompFp16, input, packed, + nullptr, false, nullptr, nullptr); + Prepack(input, ref); + Check(packed, ref); + } + + public: + MlasNeonFp16PrepackTest() + : seed_(19287), gen_(seed_), distrib_(0, 255) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16Prepack"; + } + + void ExecuteShort(void) override { + TestPrepack<1, 1, 16>(); + TestPrepack<1, 15, 16>(); + TestPrepack<1, 31, 16>(); + TestPrepack<8, 1, 16>(); + TestPrepack<8, 16, 16>(); + TestPrepack<9, 31, 16>(); + TestPrepack<9, 33, 32>(); + TestPrepack<15, 33, 16>(); + TestPrepack<17, 67, 16>(); + TestPrepack<17, 96, 128>(); + TestPrepack<263, 263, 16>(); + } +}; + +class MlasNeonFp16DequantBTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + std::uniform_real_distribution _distribFp; + MatrixGuardBuffer input_, zero_points_; + MatrixGuardBuffer dequant_, ref_, scales_; + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? 
(v >> 4) : (v & 0x0f); + } + + template + void DequantB(const uint8_t* src, MLAS_FP16* dst, const MLAS_FP16* scales, const uint8_t* zero_points) { + constexpr size_t blkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ld_src = (blkNum * BlkLen + 1) / 2; + constexpr size_t ld_dst = blkNum * BlkLen; + constexpr size_t ld_zp = (blkNum + 1) / 2; + size_t n = 0; + for (; n + 8 <= N; n += 8) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + for (size_t i = 0; i < BlkLen; i += 2, i_dst += 8) { + for (size_t j = 0; j < 8; ++j, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp + ld_zp * j], blk) : 8); + float scale = scales[i_scale + blkNum * j]; + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + + for (; n < N; ++n) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp], blk) : 8); + float scale = scales[i_scale]; + for (size_t i = 0; i < BlkLen; i += 16, i_dst += 8) { + for (size_t j = 0; j < 16; j += 2, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = std::abs(v0.ToFloat()), f1 = std::abs(v1.ToFloat()); + return std::abs(f0 - f1) <= f1 * rtol + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < 8; ++j) { + size_t idx = n * Ldb + i * 8 + j; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; ++i) { + size_t idx = n * Ldb + i; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i; + } + } + } + + template + void TestDequant() { + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t BCount = BlkNum * BlkLen * N; + constexpr size_t ScaleCount = N * BlkNum; + constexpr size_t ZpSize = N * ((BlkNum + 1) / 2); + + auto InitializeBuffer_i8 = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + auto InitializeBuffer_fp16 = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(_distribFp(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BCount / 2, InitializeBuffer_i8); + const auto* zero_points = zero_points_.GetFilledBuffer(ZpSize, InitializeBuffer_i8); + auto* dequant = dequant_.GetBuffer(BCount); + auto* ref = ref_.GetBuffer(BCount); + const auto* scales = scales_.GetFilledBuffer(ScaleCount, InitializeBuffer_fp16); + 
GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant, reinterpret_cast(input), scales, + UseZeroPoints ? reinterpret_cast(zero_points) : nullptr, + N, K, BlkNum); + DequantB(input, ref, scales, zero_points); + Check(dequant, ref); + } + + public: + MlasNeonFp16DequantBTest() + : seed_(19287), gen_(seed_), distrib_(0, 255), _distribFp(0.5f, 2.0f) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16DequantB"; + } + + void ExecuteShort(void) override { + TestDequant<1, 1, 16, false>(); + TestDequant<1, 1, 16, true>(); + TestDequant<1, 15, 16, false>(); + TestDequant<1, 15, 16, true>(); + TestDequant<1, 31, 16, false>(); + TestDequant<1, 31, 16, true>(); + TestDequant<8, 1, 16, false>(); + TestDequant<8, 1, 16, true>(); + TestDequant<8, 16, 16, false>(); + TestDequant<8, 16, 16, true>(); + TestDequant<9, 31, 16, false>(); + TestDequant<9, 31, 16, true>(); + TestDequant<9, 33, 32, false>(); + TestDequant<9, 33, 32, true>(); + TestDequant<15, 33, 16, false>(); + TestDequant<15, 33, 16, true>(); + TestDequant<17, 67, 16, false>(); + TestDequant<17, 67, 16, true>(); + TestDequant<17, 96, 128, false>(); + TestDequant<17, 96, 128, true>(); + TestDequant<263, 263, 16, false>(); + TestDequant<263, 263, 16, true>(); + } +}; + +class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + MatrixGuardBuffer A_, B_, C_, ref_, bias_; + + MLAS_FORCEINLINE + void InitializeBuffer(MLAS_FP16* buffer, float min, float max, size_t count) { + std::uniform_real_distribution distrib(min, max); + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib(gen_)); + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + float GetBVal(const MLAS_FP16* B, size_t n, size_t k) { + size_t i; + if ((N & (~7)) > n) { + size_t full8 = n & (~7); + i = full8 * ldb + 8 * k + (n - full8); + } else { + i = n * ldb + k; + } + return B[i].ToFloat(); + } + + template + void MatMul(const MLAS_FP16* A, const MLAS_FP16* B, const MLAS_FP16* bias, MLAS_FP16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = UseBias ? 
bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[m * K + k].ToFloat(); + float b = GetBVal(B, n, k); + accu = accu + a * b; + } + C[m * N + n] = MLAS_FP16(accu); + } + } + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = m * Ldc + n; + ASSERT_TRUE(FloatEqual(target[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ + << " v0 " << target[i] << " v1 " << ref[i] + << " m " << m << " n " << n; + } + } + } + + template + void TestHQ4BitGemmKernel() { + static_assert(M <= 2); + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ldb = BlkNum * BlkLen; + + const auto* A = A_.GetFilledBuffer(M * K, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + const auto* B = B_.GetFilledBuffer(ldb * N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + auto* bias = bias_.GetFilledBuffer(N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -5.0f, 5.0f, t); + }); + + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + A, B, UseBias ? bias : nullptr, C, M, N, K, K, ldb, N); + + MatMul(A, B, bias, ref); + Check(C, ref); + } + + public: + MlasNeonFp16HQ4BitGemmKernelTest() + : seed_(19287), gen_(seed_) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16HQ4BitGemmKernel"; + } + + template + void ExecuteShort_T(void) { + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + } + + void ExecuteShort(void) override { + ExecuteShort_T<1>(); + ExecuteShort_T<2>(); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + if (GetMlasPlatform().QNBitGemmDispatch) { + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmPackQuantBData) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + } + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 0710981fa17c6..e22018ae2877f 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -18,11 +18,11 @@ Module Name: #include "mlas_q4.h" #include "mlas_qnbit.h" -static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE 
ComputeType) { +static constexpr const char* ComputeTypeName(MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType) { switch (ComputeType) { - case CompFp32: + case SQNBIT_CompFp32: return "Fp32"; - case CompInt8: + case SQNBIT_CompInt8: return "Int8"; default: return "unknown"; @@ -63,16 +63,16 @@ class MlasSQNBitGemmTest : public MlasTestBase { float* C, size_t ldc, void* Workspace, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { - MLAS_SQNBIT_GEMM_DATA_PARAMS params; + MLAS_QNBIT_GEMM_DATA_PARAMS params; params.A = A; params.lda = lda; params.Bias = Bias; params.C = C; params.ldc = ldc; #ifdef MLAS_TARGET_AMD64_IX86 - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { params.QuantBDataWorkspace = PackedQuantBDataWorkspace; } #endif @@ -81,7 +81,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); } void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { @@ -201,7 +201,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { public: void Test(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; @@ -265,19 +265,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } void* PackedQuantBDataWorkspace = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); bool has_zp_input = QuantBZeroPoint != nullptr; - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, - QuantBScale, has_zp_input, QuantBZeroPoint, - GetMlasThreadPool()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, + QuantBScale, has_zp_input, QuantBZeroPoint, + GetMlasThreadPool()); } CallGemm(M, N, K, @@ -289,9 +289,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { ComputeType, Threadpool); - if (ComputeType == CompFp32) { + if (ComputeType == SQNBIT_CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else if (ComputeType == CompInt8) { + } else if (ComputeType == SQNBIT_CompInt8) { CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else { FAIL() << "Test is not implemented for compute type " @@ -324,7 +324,7 @@ template class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + 
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) : M_(M), N_(N), @@ -341,11 +341,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture= range ? FillValue - range : FillValue; } }); } diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 45aaca1ceae56..e564443ed8eb0 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -25,7 +25,7 @@ #include "core/common/logging/logging.h" #include "core/common/common.h" #include "core/platform/env.h" -#include "core/platform/ort_mutex.h" +#include #include "core/platform/path_lib.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/allocator.h" @@ -288,12 +288,12 @@ class OnnxTestCase : public ITestCase { private: std::string test_case_name_; mutable std::vector debuginfo_strings_; - mutable onnxruntime::OrtMutex m_; + mutable std::mutex m_; std::vector test_data_dirs_; std::string GetDatasetDebugInfoString(size_t dataset_id) const override { - std::lock_guard l(m_); + std::lock_guard l(m_); if (dataset_id < debuginfo_strings_.size()) { return debuginfo_strings_[dataset_id]; } @@ -488,7 +488,7 @@ void OnnxTestCase::LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b, if (st.IsOK()) { // has an all-in-one input file std::ostringstream oss; { - std::lock_guard l(m_); + std::lock_guard l(m_); oss << debuginfo_strings_[id]; } ORT_TRY { @@ -503,7 +503,7 @@ void OnnxTestCase::LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b, } { - std::lock_guard l(m_); + std::lock_guard l(m_); debuginfo_strings_[id] = oss.str(); } return; @@ -1026,7 +1026,13 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"dequantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"dequantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"quantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, - {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}}); + {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"qlinearmatmul_2D_int8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_2D_int8_float32", "result diff", {}}, + {"qlinearmatmul_2D_uint8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_3D_int8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_3D_int8_float32", "result diff", {}}, + {"qlinearmatmul_3D_uint8_float16", "fp16 type ont supported by CPU EP", {}}}); // Some EPs may fail to pass some specific testcases. // For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16. 
diff --git a/onnxruntime/test/onnx/TestResultStat.h b/onnxruntime/test/onnx/TestResultStat.h index 5bfc04c3cd577..0804b1d7a4139 100644 --- a/onnxruntime/test/onnx/TestResultStat.h +++ b/onnxruntime/test/onnx/TestResultStat.h @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include @@ -26,22 +26,22 @@ class TestResultStat { TestResultStat() : succeeded(0), not_implemented(0), load_model_failed(0), throwed_exception(0), result_differs(0), skipped(0), invalid_graph(0) {} void AddNotImplementedKernels(const std::string& s) { - std::lock_guard l(m_); + std::lock_guard l(m_); not_implemented_kernels.insert(s); } void AddFailedKernels(const std::string& s) { - std::lock_guard l(m_); + std::lock_guard l(m_); failed_kernels.insert(s); } void AddFailedTest(const std::pair& p) { - std::lock_guard l(m_); + std::lock_guard l(m_); failed_test_cases.insert(p); } const std::set>& GetFailedTest() const { - std::lock_guard l(m_); + std::lock_guard l(m_); return failed_test_cases; } @@ -74,7 +74,7 @@ class TestResultStat { } private: - mutable onnxruntime::OrtMutex m_; + mutable std::mutex m_; std::unordered_set not_implemented_kernels; std::unordered_set failed_kernels; std::set> failed_test_cases; // pairs of test name and version diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 6d86e4c35af85..99c3e44e13013 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -77,6 +77,8 @@ void usage() { "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -452,8 +454,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (ep_context_enable) sf.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - if (disable_ep_context_embed_mode) + if (disable_ep_context_embed_mode) { sf.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + } else { + sf.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "1"); + } for (auto& it : session_config_entries) { sf.AddConfigEntry(it.first.c_str(), it.second.c_str()); @@ -587,20 +592,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), std::ostream_iterator(str_stream, ",")); std::string str = str_stream.str(); - ORT_THROW("Wrong value for enable_htp_fp16_precision. 
select from: " + str); + ORT_THROW("Wrong value for ", key, ". select from: ", str); } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', -'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision'])"); +'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])"); } qnn_options[key] = value; @@ -629,7 +634,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { } if (enable_coreml) { #ifdef USE_COREML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0)); + sf.AppendExecutionProvider("CoreML", {}); #else fprintf(stderr, "CoreML is not supported in this build"); return -1; diff --git a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc new file mode 100644 index 0000000000000..f6158d8cbc12b --- /dev/null +++ b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc @@ -0,0 +1,152 @@ +#ifdef _WIN32 + +#include "core/platform/threadpool.h" +#include "core/util/thread_utils.h" +#include + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + +#include "core/framework/allocator.h" +#include "core/framework/config_options.h" +#include "core/framework/data_transfer_manager.h" +#include "core/framework/op_kernel_info.h" +#include "core/framework/ort_value_name_idx_map.h" +#include "core/platform/windows/env.h" +#include "core/providers/cpu/nn/layer_norm_impl.h" +#include "core/providers/cpu/cpu_provider_factory.h" +#include "core/providers/cpu/cpu_provider_factory_creator.h" +#include "core/util/thread_utils.h" + +#include "test/onnx/microbenchmark/common.h" + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + +using namespace onnxruntime; + +namespace { + +std::vector createMLFloat16Vector(float* vals, int64_t num_elems) { + std::vector fp16vec; + fp16vec.reserve(num_elems); + + for (int64_t i = 0; i < num_elems; i++) { + fp16vec.push_back(MLFloat16(vals[i])); + } + + return fp16vec; +} + +} // namespace + +template +static void BM_LayerNormalization(benchmark::State& state) { + bool simplified = false; + const float epsilon = 1e-05f; + int64_t axis = 1; + + onnxruntime::Node node; + // Required by LayerNormImpl constructor + node.AddAttribute("axis", axis); + node.AddAttribute("epsilon", epsilon); + + KernelDef kernel_def; + std::unique_ptr execution_provider = CPUProviderFactoryCreator::Create(true)->CreateProvider(); + std::unordered_map constant_initialized_tensors; + OrtValueNameIdxMap mlvalue_name_idx_map; + DataTransferManager data_transfer_mgr; + AllocatorMap allocators; + ConfigOptions config_options; + + OpKernelInfo op_kernel_info(node, kernel_def, *execution_provider, constant_initialized_tensors, mlvalue_name_idx_map, + data_transfer_mgr, allocators, config_options); + + LayerNormImpl layer_norm_impl(op_kernel_info); + + const std::vector dims{1, 256, 1024}; + const size_t num_elems = dims[0] * dims[1] * dims[2]; + + TensorShape x_shape(dims); + TensorShape scale_shape(dims); + TensorShape bias_shape(dims); + + const float low = -1.0f; + const float high = 1.0f; + + float* x_float = GenerateArrayWithRandomValue(num_elems, low, high); + float* scale_float = GenerateArrayWithRandomValue(num_elems, 0.1f, 
high); + float* bias_float = GenerateArrayWithRandomValue(num_elems, low, high); + + std::vector x_MLFloat16 = createMLFloat16Vector(x_float, num_elems); + std::vector scale_MLFloat16 = createMLFloat16Vector(scale_float, num_elems); + std::vector bias_MLFloat16 = createMLFloat16Vector(bias_float, num_elems); + + T* x_data = nullptr; + T* scale_data = nullptr; + T* bias_data = nullptr; + if (std::is_same_v) { + x_data = (T*)x_MLFloat16.data(); + scale_data = (T*)scale_MLFloat16.data(); + bias_data = (T*)bias_MLFloat16.data(); + } else if (std::is_same_v) { + x_data = (T*)x_float; + scale_data = (T*)scale_float; + bias_data = (T*)bias_float; + } + assert(x_data); + + T* Y_data = static_cast(aligned_alloc(num_elems * sizeof(T), 64)); + U* mean_data = static_cast(aligned_alloc(num_elems * sizeof(U), 64)); + U* inv_std_dev_data = static_cast(aligned_alloc(num_elems * sizeof(U), 64)); + + OrtThreadPoolParams tp_params; + tp_params.name = ORT_TSTR("intra-op"); + std::unique_ptr thread_pool = concurrency::CreateThreadPool( + &Env::Default(), tp_params, concurrency::ThreadPoolType::INTRA_OP); + + OrtMemoryInfo memory_info(onnxruntime::CPU, OrtAllocatorType::OrtArenaAllocator); + AllocatorPtr alloc = std::make_shared(memory_info); + for (auto _ : state) { + auto status = layer_norm_impl.ComputeWithoutContext(x_data, + x_shape, + scale_data, + static_cast(scale_shape.Size()), + bias_data, + static_cast(bias_shape.Size()), + Y_data, + mean_data, + inv_std_dev_data, + thread_pool.get(), + axis, + epsilon, + simplified, + alloc); + if (!status.IsOK()) { + std::cout << "ComputeWithoutContext status not OK: " << status.ErrorMessage() << std::endl; + break; + } + } + + aligned_free(x_float); + aligned_free(scale_float); + aligned_free(bias_float); + aligned_free(Y_data); + aligned_free(mean_data); + aligned_free(inv_std_dev_data); +} + +BENCHMARK(BM_LayerNormalization) + ->Arg(1) + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_LayerNormalization) + ->Arg(1) + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +#endif diff --git a/onnxruntime/test/onnx/onnxruntime_event.h b/onnxruntime/test/onnx/onnxruntime_event.h index b830a9f888edb..a7cfbccad3d8a 100644 --- a/onnxruntime/test/onnx/onnxruntime_event.h +++ b/onnxruntime/test/onnx/onnxruntime_event.h @@ -2,12 +2,12 @@ // Licensed under the MIT License. 
#include -#include +#include struct OnnxRuntimeEvent { public: - onnxruntime::OrtMutex finish_event_mutex; - onnxruntime::OrtCondVar finish_event_data; + std::mutex finish_event_mutex; + std::condition_variable finish_event_data; bool finished = false; OnnxRuntimeEvent() = default; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 3aec0d5a67e94..2ff0b599beebf 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -831,7 +831,8 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_mapName() == "ConstantFolding") { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), TransformerLevel::Level1)); @@ -1764,6 +1765,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) { } } +TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx"; @@ -4675,7 +4705,8 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { // Compare results double per_sample_tolerance = 1e-3; double relative_per_sample_tolerance = 0.0; - auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], per_sample_tolerance, relative_per_sample_tolerance, false); + auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], + per_sample_tolerance, relative_per_sample_tolerance, false); EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; } @@ -4684,7 +4715,8 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt std::make_unique(CPUExecutionProviderInfo()); bool has_gelu_approximation = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "GeluApproximation") { has_gelu_approximation = true; @@ -4699,7 +4731,8 @@ TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { std::unique_ptr 
cpu_ep = std::make_unique(CPUExecutionProviderInfo()); bool has_double_qdq_remover = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "DoubleQDQPairsRemover") { has_double_qdq_remover = true; @@ -5859,6 +5892,22 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) { std::map op_to_count = CountOpsInGraph(graph); EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); } + +TEST_F(GraphTransformationTests, MatMulIntegerToFloatLargeTensorTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float_large_tensor.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType(kDmlExecutionProvider); + } + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 0); +} #endif // USE_DML #endif diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 66b74641e41d3..caa64560426af 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -36,9 +36,11 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger, + disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -61,8 +63,9 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, + disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index 
81c1a4ace1e33..b306f026b2dfd 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -27,6 +27,7 @@ namespace test { TEST(OptimizerTest, Basic) { Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); auto& graph = model.MainGraph(); constexpr int tensor_dim = 10; @@ -66,22 +67,21 @@ TEST(OptimizerTest, Basic) { auto cpu_execution_provider = std::make_unique(CPUExecutionProviderInfo()); #if !defined(DISABLE_SPARSE_TENSORS) - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [&graph](const std::string& name) -> bool { - return graph.IsSparseInitializer(name); - }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [&graph](const std::string& name) -> bool { + return graph.IsSparseInitializer(name); + }, + logger); #else - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [](std::string const&) { return false; }, + logger); #endif //! defined(DISABLE_SPARSE_TENSORS) std::vector fetch_mlvalue_idxs{info.GetMLValueIndex("out")}; OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); - const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); const ConfigOptions empty_config_options; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d07977d4b97b8..043b92d7ef121 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -11,6 +11,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" +#include "core/optimizer/qdq_transformer/bias_quantization.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -3927,6 +3928,7 @@ TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) { TEST(QDQTransformerTests, QDQ_Selector_Test) { const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx"); + const auto& logger = DefaultLoggingManager().DefaultLogger(); SessionOptions so; // We want to keep the graph un-optimized to prevent QDQ transformer to kick in @@ -3961,7 +3963,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check if SelectorManager get a conv qdq group selection as expected { - const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger); ASSERT_FALSE(result.empty()); const auto& qdq_group = result.at(0); ASSERT_EQ(std::vector({0, 1, 2}), qdq_group.dq_nodes); @@ -3976,7 +3978,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger); // We 
should get a single QDQ Node unit in the result ASSERT_EQ(1, node_unit_holder.size()); @@ -4044,7 +4046,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check SelectorManager will get empty result { - const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer, logger); ASSERT_TRUE(result.empty()); } } @@ -4846,5 +4848,95 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) { } #endif // !defined(DISABLE_CONTRIB_OPS) +TEST(QDQTransformerTests, BiasQuantization_Conv) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = builder.MakeInput({1, 24, 128, 128}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({24, 1, 3, 3}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* bias_arg = builder.MakeInitializer({24}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* weight_dq_arg = builder.MakeIntermediate(); + NodeArg* conv_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.07f, static_cast(0), input_dq_arg, + use_contrib_qdq); + auto& weight_dq_node = builder.AddDequantizeLinearNode(weight_arg, std::vector(24, 0.05f), + std::vector(24, static_cast(0)), + weight_dq_arg, nullptr, use_contrib_qdq); + weight_dq_node.AddAttribute("axis", static_cast(0)); + auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_dq_arg, bias_arg}, {conv_dq_arg}); + conv_node.AddAttribute("dilations", std::vector{1, 1}); + conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + conv_node.AddAttribute("strides", std::vector{1, 1}); + conv_node.AddAttribute("group", static_cast(24)); + conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + builder.AddQuantizeLinearNode(conv_dq_arg, 0.14f, static_cast(127), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["QLinearConv"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); +#endif +} + +TEST(QDQTransformerTests, BiasQuantization_Gemm) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = + builder.MakeInput({1, 32}, std::numeric_limits::min(), std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({16, 32}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* bias_arg = builder.MakeInitializer({16}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* weight_dq_arg = builder.MakeIntermediate(); + NodeArg* gemm_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.001f, static_cast(0), input_dq_arg, + use_contrib_qdq); + builder.AddDequantizeLinearNode(weight_arg, 0.26f, static_cast(0), weight_dq_arg, + 
use_contrib_qdq); + auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg}); + gemm_node.AddAttribute("transB", static_cast(1)); + builder.AddQuantizeLinearNode(gemm_dq_arg, 0.144f, static_cast(69), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); +#endif +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 35ba1a3369597..f6fce37322c10 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -22,6 +22,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/internal_testing/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test_utils.h" @@ -3800,6 +3801,46 @@ TEST(TransposeOptimizerTests, TestCast) { /*opset_version*/ {15, 18}); } +TEST(TransposeOptimizerTests, TestQLinearSoftmax) { + auto build_test_case_1 = [&](ModelTestBuilder& builder) { + auto* input0_arg = MakeInput(builder, std::nullopt, {1, 384, 384, 21}, 0, 255); + auto* transpose_1_out_0 = builder.MakeIntermediate(); + auto* input_x_scale = builder.MakeScalarInitializer(0.5086354613304138); + auto* input_x_zero_point = builder.MakeScalarInitializer(74); + auto* input_y_scale = builder.MakeScalarInitializer(0.003921568859368563); + auto* input_y_zero_point = builder.MakeScalarInitializer(0); + auto* qlinearsoftmax_1_out_0 = builder.MakeIntermediate(); + auto* transpose_2_out_0 = builder.MakeOutput(); + + auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); + auto& qlinearsoftmax_1 = builder.AddNode("QLinearSoftmax", + {transpose_1_out_0, input_x_scale, input_x_zero_point, input_y_scale, input_y_zero_point}, + {qlinearsoftmax_1_out_0}, kMSDomain); + qlinearsoftmax_1.AddAttribute("axis", static_cast(1)); + qlinearsoftmax_1.AddAttribute("opset", static_cast(13)); + auto& transpose_2 = builder.AddNode("Transpose", {qlinearsoftmax_1_out_0}, {transpose_2_out_0}); + transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); + }; + + auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { + int transpose_cost = EstimateTransposeCost(session.GetGraph()); + EXPECT_EQ(transpose_cost, 0); + }; + + TransformerTester(build_test_case_1, + check_optimized_graph_1, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 13, + /*per_sample_tolerance*/ 0.0, + /*relative_per_sample_tolerance*/ 0.0, + /*transformer*/ nullptr, + /*add_session_options*/ {}, + /*disabled_optimizers*/ {}, + /*ep*/ DefaultCpuExecutionProvider()); +} + 
TEST(TransposeOptimizerTests, TestBroadcastReusedInputs) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{-1, -1, 3, 4}}, {1, 2, 3, 4}, 0.0, 1.0); diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 42b73ec384cf5..23c3812ebd025 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -24,6 +24,7 @@ #include #include "test_configuration.h" +#include "strings_helper.h" namespace onnxruntime { namespace perftest { @@ -36,7 +37,7 @@ namespace perftest { "\t\tProvide 'duration' to run the test for a fix duration, and 'times' to repeated for a certain times. \n" "\t-M: Disable memory pattern.\n" "\t-A: Disable memory arena\n" - "\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n" + "\t-I: Generate tensor input binding. Free dimensions are treated as 1 unless overridden using -f.\n" "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n" "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai|webgpu]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'. " @@ -76,11 +77,10 @@ namespace perftest { "\n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" - "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\n" "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" @@ -99,6 +99,9 @@ namespace perftest { "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." 
"\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" @@ -128,8 +131,14 @@ namespace perftest { "\t [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n" "\t [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n" "\n" - "\t [CoreML only] [COREML_FLAG_CREATE_MLPROGRAM]: Create an ML Program model instead of Neural Network.\n" - "\t [Example] [For CoreML EP] -e coreml -i \"COREML_FLAG_CREATE_MLPROGRAM\"\n" + "\t [CoreML only] [ModelFormat]:[MLProgram, NeuralNetwork] Create an ML Program model or Neural Network. Default is NeuralNetwork.\n" + "\t [CoreML only] [MLComputeUnits]:[CPUAndNeuralEngine CPUAndGPU ALL CPUOnly] Specify to limit the backend device used to run the model.\n" + "\t [CoreML only] [AllowStaticInputShapes]:[0 1].\n" + "\t [CoreML only] [EnableOnSubgraphs]:[0 1].\n" + "\t [CoreML only] [SpecializationStrategy]:[Default FastPrediction].\n" + "\t [CoreML only] [ProfileComputePlan]:[0 1].\n" + "\t [CoreML only] [AllowLowPrecisionAccumulationOnGPU]:[0 1].\n" + "\t [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram MLComputeUnits|CPUAndGPU\"\n" "\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" "\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n" @@ -174,39 +183,6 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, return true; } -static bool ParseSessionConfigs(const std::string& configs_string, - std::unordered_map& session_configs) { - std::istringstream ss(configs_string); - std::string token; - - while (ss >> token) { - if (token == "") { - continue; - } - - std::string_view token_sv(token); - - auto pos = token_sv.find("|"); - if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) { - // Error: must use a '|' to separate the key and value for session configuration entries. 
- return false; - } - - std::string key(token_sv.substr(0, pos)); - std::string value(token_sv.substr(pos + 1)); - - auto it = session_configs.find(key); - if (it != session_configs.end()) { - // Error: specified duplicate session configuration entry: {key} - return false; - } - - session_configs.insert(std::make_pair(std::move(key), std::move(value))); - } - - return true; -} - /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) { @@ -381,7 +357,13 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.run_config.intra_op_thread_affinities = ToUTF8String(optarg); break; case 'C': { - if (!ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries)) { + ORT_TRY { + ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "Error parsing session configuration entries: %s\n", ex.what()); + }); return false; } break; diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index eb230ac771e13..a96028ed3903e 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -17,6 +17,11 @@ #include #include "providers.h" #include "TestCase.h" +#include "strings_helper.h" + +#ifdef USE_OPENVINO +#include "nlohmann/json.hpp" +#endif #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" @@ -39,13 +44,8 @@ std::chrono::duration OnnxRuntimeTestSession::Run() { auto& input = test_inputs_.at(id); auto start = std::chrono::high_resolution_clock::now(); - if (!use_device_mem) { - auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), - output_names_raw_ptr.data(), output_names_raw_ptr.size()); - } else { - session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), - output_names_raw_ptr.data(), outputs_.data(), output_names_raw_ptr.size()); - } + session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), + output_names_raw_ptr.data(), outputs_.data(), output_names_raw_ptr.size()); auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration_seconds = end - start; @@ -59,6 +59,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device Ort::SessionOptions session_options; provider_name_ = performance_test_config.machine_config.provider_type_name; + std::unordered_map provider_options; if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { #ifdef USE_DNNL // Generate provider options @@ -73,24 +74,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif // defined(_MSC_VER) int num_threads = 0; - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW( - "[ERROR] [OneDNN] Use a '|' to separate the key and value for the " - "run-time option you are trying to use.\n"); - } - - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - - if (key == 
"num_of_threads") { - std::stringstream sstream(value); + ParseSessionConfigs(ov_string, provider_options, {"num_of_threads"}); + for (const auto& provider_option : provider_options) { + if (provider_option.first == "num_of_threads") { + std::stringstream sstream(provider_option.second); sstream >> num_threads; if (num_threads < 0) { ORT_THROW( @@ -98,10 +85,6 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device " set number of threads or use '0' for default\n"); // If the user doesnt define num_threads, auto detect threads later } - } else { - ORT_THROW( - "[ERROR] [OneDNN] wrong key type entered. " - "Choose from the following runtime key options that are available for OneDNN. ['num_of_threads']\n"); } } dnnl_options.threadpool_args = static_cast(&num_threads); @@ -145,22 +128,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW( - "[ERROR] [CUDA] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - buffer.emplace_back(token.substr(0, pos)); - option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(token.substr(pos + 1)); - option_values.push_back(buffer.back().c_str()); + ParseSessionConfigs(ov_string, provider_options); + for (const auto& provider_option : provider_options) { + option_keys.push_back(provider_option.first.c_str()); + option_values.push_back(provider_option.second.c_str()); } Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options, @@ -193,24 +164,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW( - "[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - buffer.emplace_back(token.substr(0, pos)); - option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(token.substr(pos + 1)); - option_values.push_back(buffer.back().c_str()); + ParseSessionConfigs(ov_string, provider_options); + for (const auto& provider_option : provider_options) { + option_keys.push_back(provider_option.first.c_str()); + option_values.push_back(provider_option.second.c_str()); } - Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options, option_keys.data(), option_values.data(), option_keys.size())); if (!status.IsOK()) { @@ -240,22 +198,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string option_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(option_string); - std::string token; - std::unordered_map qnn_options; - - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use."); - } - - std::string 
key(token.substr(0, pos)); - std::string value(token.substr(pos + 1)); - + ParseSessionConfigs(option_string, provider_options, + {"backend_path", "profiling_file_path", "profiling_level", "rpc_control_latency", + "vtcm_mb", "soc_model", "device_id", "htp_performance_mode", "qnn_saver_path", + "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch", + "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer"}); + for (const auto& provider_option : provider_options) { + const std::string& key = provider_option.first; + const std::string& value = provider_option.second; if (key == "backend_path" || key == "profiling_file_path") { if (value.empty()) { ORT_THROW("Please provide the valid file path."); @@ -303,25 +253,18 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), std::ostream_iterator(str_stream, ",")); std::string str = str_stream.str(); - ORT_THROW("Wrong value for " + key + ". select from: " + str); + ORT_THROW("Wrong value for ", key, ". select from: ", str); } - } else { - ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', -'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', -'htp_arch', 'device_id', 'enable_htp_fp16_precision'])"); } - - qnn_options[key] = value; } - session_options.AppendExecutionProvider("QNN", qnn_options); + session_options.AppendExecutionProvider("QNN", provider_options); #else ORT_THROW("QNN is not supported in this build\n"); #endif @@ -332,22 +275,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string option_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(option_string); - std::string token; - std::unordered_map snpe_options; - - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - std::string key(token.substr(0, pos)); - std::string value(token.substr(pos + 1)); - + ParseSessionConfigs(option_string, provider_options, {"runtime", "priority", "buffer_type", "enable_init_cache"}); + for (const auto& provider_option : provider_options) { if (key == "runtime") { std::set supported_runtime = {"CPU", "GPU_FP32", "GPU", "GPU_FLOAT16", "DSP", "AIP_FIXED_TF"}; if (supported_runtime.find(value) == supported_runtime.end()) { @@ -366,14 +295,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); if (value != "1") { ORT_THROW("Set to 1 to enable_init_cache."); } - } else { - ORT_THROW("Wrong key type entered. 
Choose from options: ['runtime', 'priority', 'buffer_type', 'enable_init_cache'] \n"); } - - snpe_options[key] = value; } - session_options.AppendExecutionProvider("SNPE", snpe_options); + session_options.AppendExecutionProvider("SNPE", provider_options); #else ORT_THROW("SNPE is not supported in this build\n"); #endif @@ -417,24 +342,43 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { #ifdef __APPLE__ #ifdef USE_COREML - uint32_t coreml_flags = 0; std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; - std::istringstream ss(ov_string); - - std::string key; - while (ss >> key) { - if (key == "COREML_FLAG_CREATE_MLPROGRAM") { - coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; - std::cout << "Enabling ML Program.\n"; - } else if (key.empty()) { + static const std::unordered_set available_keys = {kCoremlProviderOption_MLComputeUnits, + kCoremlProviderOption_ModelFormat, + kCoremlProviderOption_RequireStaticInputShapes, + kCoremlProviderOption_EnableOnSubgraphs, + kCoremlProviderOption_SpecializationStrategy, + kCoremlProviderOption_ProfileComputePlan, + kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU}; + ParseSessionConfigs(ov_string, provider_options, available_keys); + + std::unordered_map available_options = { + {"CPUAndNeuralEngine", "1"}, + {"CPUAndGPU", "1"}, + {"CPUOnly", "1"}, + {"ALL", "1"}, + }; + for (const auto& provider_option : provider_options) { + if (provider_option.first == kCoremlProviderOption_MLComputeUnits && + available_options.find(provider_option.second) != available_options.end()) { + } else if (provider_option.first == kCoremlProviderOption_ModelFormat && + (provider_option.second == "MLProgram" || provider_option.second == "NeuralNetwork")) { + } else if (provider_option.first == kCoremlProviderOption_RequireStaticInputShapes && + (provider_option.second == "1" || provider_option.second == "0")) { + } else if (provider_option.first == kCoremlProviderOption_EnableOnSubgraphs && + (provider_option.second == "0" || provider_option.second == "1")) { + } else if (provider_option.first == kCoremlProviderOption_SpecializationStrategy && + (provider_option.second == "Default" || provider_option.second == "FastPrediction")) { + } else if (provider_option.first == kCoremlProviderOption_ProfileComputePlan && + (provider_option.second == "0" || provider_option.second == "1")) { + } else if (provider_option.first == kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU && + (provider_option.second == "0" || provider_option.second == "1")) { } else { - ORT_THROW( - "[ERROR] [CoreML] wrong key type entered. Choose from the following runtime key options " - "that are available for CoreML. ['COREML_FLAG_CREATE_MLPROGRAM'] \n"); + ORT_THROW("Invalid value for option ", provider_option.first, ": ", provider_option.second); } } // COREML_FLAG_CREATE_MLPROGRAM - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags)); + session_options.AppendExecutionProvider("CoreML", provider_options); #else ORT_THROW("CoreML is not supported in this build\n"); #endif @@ -443,34 +387,20 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); #endif } else if (provider_name_ == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML - std::unordered_map dml_options; - dml_options["performance_preference"] = "high_performance"; - dml_options["device_filter"] = "gpu"; - dml_options["disable_metacommands"] = "false"; - dml_options["enable_graph_capture"] = "false"; #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [DML] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - + ParseSessionConfigs(ov_string, provider_options, + {"device_filter", "performance_preference", "disable_metacommands", + "enable_graph_capture", "enable_graph_serialization"}); + for (const auto& provider_option : provider_options) { + const std::string& key = provider_option.first; + const std::string& value = provider_option.second; if (key == "device_filter") { std::set ov_supported_device_types = {"gpu", "npu"}; if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong configuration value for the key 'device_filter'. " @@ -479,7 +409,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (key == "performance_preference") { std::set ov_supported_values = {"default", "high_performance", "minimal_power"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. " @@ -488,7 +417,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (key == "disable_metacommands") { std::set ov_supported_values = {"true", "True", "false", "False"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong value for the key 'disable_metacommands'. " @@ -497,7 +425,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (key == "enable_graph_capture") { std::set ov_supported_values = {"true", "True", "false", "False"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong value for the key 'enable_graph_capture'. " @@ -514,7 +441,19 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); } } } - session_options.AppendExecutionProvider("DML", dml_options); + if (provider_options.find("performance_preference") == provider_options.end()) { + provider_options["performance_preference"] = "high_performance"; + } + if (provider_options.find("device_filter") == provider_options.end()) { + provider_options["device_filter"] = "gpu"; + } + if (provider_options.find("disable_metacommands") == provider_options.end()) { + provider_options["disable_metacommands"] = "false"; + } + if (provider_options.find("enable_graph_capture") == provider_options.end()) { + provider_options["enable_graph_capture"] = "false"; + } + session_options.AppendExecutionProvider("DML", provider_options); #else ORT_THROW("DML is not supported in this build\n"); #endif @@ -525,21 +464,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif // defined(_MSC_VER) - std::istringstream ss(ov_string); - std::string token; bool enable_fast_math = false; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [ACL] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - + ParseSessionConfigs(ov_string, provider_options, {"enable_fast_math"}); + for (const auto& provider_option : provider_options) { if (key == "enable_fast_math") { std::set ov_supported_values = {"true", "True", "false", "False"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { @@ -549,9 +476,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [ACL] You have selcted an invalid value for the key 'enable_fast_math'. " "Select from 'true' or 'false' \n"); } - } else { - ORT_THROW( - "[ERROR] [ACL] Unrecognized option: ", key); } } Ort::ThrowOnError( @@ -596,8 +520,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name_ == onnxruntime::kWebGpuExecutionProvider) { #ifdef USE_WEBGPU - session_options.AppendExecutionProvider( - "WebGPU", {{"intra_op_num_threads", std::to_string(performance_test_config.run_config.intra_op_num_threads)}}); + session_options.AppendExecutionProvider("WebGPU", {}); #else ORT_THROW("WebGPU is not supported in this build\n"); #endif @@ -608,24 +531,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else std::string option_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(option_string); - std::string token; - std::unordered_map vitisai_session_options; + ParseSessionConfigs(option_string, provider_options); - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [VitisAI] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - std::string key(token.substr(0, pos)); - std::string value(token.substr(pos + 1)); - vitisai_session_options[key] = value; - } - session_options.AppendExecutionProvider_VitisAI(vitisai_session_options); + session_options.AppendExecutionProvider_VitisAI(provider_options); #else ORT_THROW("VitisAI is not supported in this build\n"); #endif @@ -807,13 +715,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. CPU only supports FP32 . \n"); } } - } else if (key == "enable_npu_fast_compile") { - if (value == "true" || value == "True" || - value == "false" || value == "False") { - ov_options[key] = value; - } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_npu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); - } } else if (key == "enable_opencl_throttling") { if (value == "true" || value == "True" || value == "false" || value == "False") { @@ -843,6 +744,28 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else { ov_options[key] = value; } + } else if (key == "load_config") { + auto load_json = [&](std::string filename) -> std::string { + std::ifstream input_filestream(filename); + if (!input_filestream.is_open()) { + ORT_THROW("Passed an invalid JSON config file path \"" + filename + "\"."); + } + nlohmann::json json_config; + try { + input_filestream >> json_config; + } catch (const OnnxRuntimeException& ex) { + ORT_THROW("Exception parsing config file \"" + filename + "\".\n" + ex.what()); + } catch (const std::exception& ex) { + throw std::runtime_error("Standard exception for config file \"" + filename + "\".\n" + ex.what()); + } catch (...) { + throw std::runtime_error("Unknown exception for config file \"" + filename + "\".\n"); + } + if (json_config.empty()) { + ORT_THROW("Empty JSON content passed \"" + filename + "\"."); + } + return json_config.dump(); + }; + ov_options[key] = load_json(value); } else if (key == "model_priority") { ov_options[key] = value; } else if (key == "cache_dir") { @@ -855,21 +778,13 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else { ov_options[key] = value; } - } else if (key == "export_ep_ctx_blob") { - if (value == "true" || value == "True" || - value == "false" || value == "False") { - ov_options[key] = value; - } else { - ORT_THROW( - "[ERROR] [OpenVINO] The value for the key 'export_ep_ctx_blob' " - "should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "use_device_mem") { - if (value == "true" || value == "True") { - use_device_mem = true; - } + } else if (key == "device_memory_name") { + device_memory_name_ = std::move(value); } else { - ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n"); + ORT_THROW( + "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." + " ['device_type', 'device_id', 'num_of_threads', 'load_config', 'cache_dir', 'num_streams', " + "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer', 'model_priority'] \n"); } } session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); @@ -912,25 +827,31 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); input_names_[i] = input_names_str_[i].c_str(); } - if (use_device_mem) { - Ort::MemoryInfo memory_info = Ort::MemoryInfo("OpenVINO_RT_NPU", OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + auto transform_fcn = std::function(); + auto new_value = std::function&, Ort::ConstTensorTypeAndShapeInfo&)>(); + if (device_memory_name_.empty()) { + transform_fcn = [](int64_t input) { return input; }; + new_value = [](OrtAllocator*, const std::vector&, Ort::ConstTensorTypeAndShapeInfo&) { + return Ort::Value(nullptr); + }; + } else { + Ort::MemoryInfo memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); custom_allocator_ = std::make_unique(session_, memory_info); - for (size_t i = 0; i < output_names_raw_ptr.size(); i++) { - Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); - auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); - - std::vector output_shape = tensor_info.GetShape(); + allocator_ = *custom_allocator_; - // free dimensions are treated as 1 if not overridden - for (int64_t& dim : output_shape) { - if (dim == -1) { - dim = 1; - } - } + // free dimensions are treated as 1 if not overridden + transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; + new_value = [](OrtAllocator* allocator, const std::vector& output_shape, Ort::ConstTensorTypeAndShapeInfo& tensor_info) { + return Ort::Value::CreateTensor(allocator, output_shape.data(), output_shape.size(), tensor_info.GetElementType()); + }; + } - outputs_.push_back(Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)output_shape.data(), - output_shape.size(), tensor_info.GetElementType())); - } + for (size_t i = 0; i < output_names_raw_ptr.size(); i++) { + Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector output_shape = tensor_info.GetShape(); + std::transform(output_shape.begin(), output_shape.end(), output_shape.begin(), transform_fcn); + outputs_.emplace_back(new_value(allocator_, output_shape, tensor_info)); } } @@ -1020,29 +941,16 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { Ort::TypeInfo type_info = session_.GetInputTypeInfo(i); if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); - if (!use_device_mem) { - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - } std::vector input_node_dim = tensor_info.GetShape(); // free dimensions are treated as 1 if not overridden - for (int64_t& dim : input_node_dim) { - if (dim == -1) { - dim = 1; - } - } - if (use_device_mem) { - Ort::Value input_tensor = Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)input_node_dim.data(), - input_node_dim.size(), tensor_info.GetElementType()); - InitializeTensorWithSeed(seed, input_tensor); - PreLoadTestData(0, i, std::move(input_tensor)); - } else { - auto allocator = Ort::AllocatorWithDefaultOptions(); - Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), - input_node_dim.size(), tensor_info.GetElementType()); - InitializeTensorWithSeed(seed, input_tensor); - PreLoadTestData(0, i, std::move(input_tensor)); - } + auto transform_fcn = [](int64_t input) { return (input == -1) ? 
-input : input; }; + std::transform(input_node_dim.begin(), input_node_dim.end(), input_node_dim.begin(), transform_fcn); + + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator_, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); } } return true; diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index e33041a2a0958..7d5e46983ad41 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -38,6 +38,7 @@ class OnnxRuntimeTestSession : public TestSession { std::mt19937 rand_engine_; std::uniform_int_distribution dist_; std::vector> test_inputs_; + OrtAllocator* allocator_ = Ort::AllocatorWithDefaultOptions(); std::unique_ptr custom_allocator_; std::vector outputs_; std::vector output_names_; @@ -48,7 +49,7 @@ class OnnxRuntimeTestSession : public TestSession { std::vector input_names_str_; const int input_length_; std::string provider_name_; - bool use_device_mem = false; + std::string device_memory_name_; // Device memory type name to use from the list in allocator.h }; } // namespace perftest diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index 08d77008dc25c..faf0c34193717 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -189,8 +189,8 @@ Status PerformanceRunner::RunParallelDuration() { // TODO: Make each thread enqueue a new worker. auto tpool = GetDefaultThreadPool(Env::Default()); std::atomic counter = {0}; - OrtMutex m; - OrtCondVar cv; + std::mutex m; + std::condition_variable cv; auto start = std::chrono::high_resolution_clock::now(); auto end = start; @@ -206,7 +206,7 @@ Status PerformanceRunner::RunParallelDuration() { if (!status.IsOK()) std::cerr << status.ErrorMessage(); // Simplified version of Eigen::Barrier - std::lock_guard lg(m); + std::lock_guard lg(m); counter--; cv.notify_all(); }); @@ -216,7 +216,7 @@ Status PerformanceRunner::RunParallelDuration() { } while (duration_seconds.count() < performance_test_config_.run_config.duration_in_seconds); // Join - std::unique_lock lock(m); + std::unique_lock lock(m); cv.wait(lock, [&counter]() { return counter == 0; }); return Status::OK(); @@ -228,8 +228,8 @@ Status PerformanceRunner::ForkJoinRepeat() { // create a threadpool with one thread per concurrent request auto tpool = std::make_unique(run_config.concurrent_session_runs); std::atomic counter{0}, requests{0}; - OrtMutex m; - OrtCondVar cv; + std::mutex m; + std::condition_variable cv; // Fork for (size_t i = 0; i != run_config.concurrent_session_runs; ++i) { @@ -242,14 +242,14 @@ Status PerformanceRunner::ForkJoinRepeat() { } // Simplified version of Eigen::Barrier - std::lock_guard lg(m); + std::lock_guard lg(m); counter--; cv.notify_all(); }); } // Join - std::unique_lock lock(m); + std::unique_lock lock(m); cv.wait(lock, [&counter]() { return counter == 0; }); return Status::OK(); diff --git a/onnxruntime/test/perftest/performance_runner.h b/onnxruntime/test/perftest/performance_runner.h index cb1cb661550a7..b0a0161e7fd6c 100644 --- a/onnxruntime/test/perftest/performance_runner.h +++ b/onnxruntime/test/perftest/performance_runner.h @@ -14,7 +14,7 @@ #include #include #include -#include +#include #include #include "test_configuration.h" #include "heap_buffer.h" @@ -75,7 +75,7 @@ class 
PerformanceRunner { ORT_RETURN_IF_ERROR(status); if (!isWarmup) { - std::lock_guard guard(results_mutex_); + std::lock_guard guard(results_mutex_); performance_result_.time_costs.emplace_back(duration_seconds.count()); performance_result_.total_time_cost += duration_seconds.count(); if (performance_test_config_.run_config.f_verbose) { @@ -116,7 +116,7 @@ class PerformanceRunner { onnxruntime::test::HeapBuffer b_; std::unique_ptr test_case_; - OrtMutex results_mutex_; + std::mutex results_mutex_; }; } // namespace perftest } // namespace onnxruntime diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc new file mode 100644 index 0000000000000..e09c8fac70887 --- /dev/null +++ b/onnxruntime/test/perftest/strings_helper.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. + +#include +#include + +#include "strings_helper.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace perftest { + +void ParseSessionConfigs(const std::string& configs_string, + std::unordered_map& session_configs, + const std::unordered_set& available_keys) { + std::istringstream ss(configs_string); + std::string token; + + while (ss >> token) { + if (token == "") { + continue; + } + + std::string_view token_sv(token); + + auto pos = token_sv.find("|"); + if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) { + ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + std::string key(token_sv.substr(0, pos)); + std::string value(token_sv.substr(pos + 1)); + + if (available_keys.empty() == false && available_keys.count(key) == 0) { + // Error: unknown option: {key} + std::string available_keys_str; + for (const auto& av_key : available_keys) { + available_keys_str += av_key; + available_keys_str += ", "; + } + ORT_THROW("[ERROR] wrong key type entered : `", key, + "`. The following runtime key options are avaible: [", available_keys_str, "]"); + } + + auto it = session_configs.find(key); + if (it != session_configs.end()) { + // Error: specified duplicate session configuration entry: {key} + ORT_THROW("Specified duplicate session configuration entry: ", key); + } + + session_configs.insert(std::make_pair(std::move(key), std::move(value))); + } +} +} // namespace perftest +} // namespace onnxruntime diff --git a/onnxruntime/test/perftest/strings_helper.h b/onnxruntime/test/perftest/strings_helper.h new file mode 100644 index 0000000000000..0d6c56709fde6 --- /dev/null +++ b/onnxruntime/test/perftest/strings_helper.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. 
+#include +#include +#include + +namespace onnxruntime { +namespace perftest { + +void ParseSessionConfigs(const std::string& configs_string, + std::unordered_map& session_configs, + const std::unordered_set& available_keys = {}); +} // namespace perftest +} // namespace onnxruntime diff --git a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm index 32b4b32e299d6..fa95c1fc52b94 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm +++ b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm @@ -35,8 +35,9 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU = #if COREML_EP_AVAILABLE if (useCoreML) { - const uint32_t flags = COREML_FLAG_USE_CPU_ONLY; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags)); + std::unordered_map provider_options = { + {kCoremlProviderOption_MLComputeUnits, "CPUOnly"}}; + session_options.AppendExecutionProvider("CoreML", provider_options); } #else (void)useCoreML; diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm index 86001b6cb50a5..b53a4a2df09b4 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -35,8 +35,9 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU = #if COREML_EP_AVAILABLE if (useCoreML) { - const uint32_t flags = COREML_FLAG_USE_CPU_ONLY; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags)); + std::unordered_map provider_options = { + {kCoremlProviderOption_MLComputeUnits, "CPUOnly"}}; + session_options.AppendExecutionProvider("CoreML", provider_options); } #else (void)useCoreML; diff --git a/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py b/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py new file mode 100644 index 0000000000000..4e5329dd5b09a --- /dev/null +++ b/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py @@ -0,0 +1,54 @@ +import argparse + +plist_file_content = """ + + + + + method + development + teamID + {team_id} + provisioningProfiles + + ai.onnxruntime.tests.ios-package-test + {provisioning_profile_uuid} + + signingStyle + manual + + +""" +if __name__ == "__main__": + # handle cli args + parser = argparse.ArgumentParser( + "Generates a PList file to the relevant destination. This PList file contains the properties to allow a user to generate an IPA file for the ios-package-test. " + ) + + parser.add_argument("--dest_file", type=str, help="Path to output the PList file to.", required=True) + parser.add_argument( + "--apple_team_id", + type=str, + help="The Team ID associated with the provisioning profile. You should be able to find this from the Apple developer portal under Membership.", + required=True, + ) + parser.add_argument( + "--provisioning_profile_uuid", + type=str, + help="The Provisioning Profile UUID, which can be found in the .mobileprovision file. 
", + required=True, + ) + + args = parser.parse_args() + + formatted_plist = plist_file_content.format( + team_id=args.apple_team_id, provisioning_profile_uuid=args.provisioning_profile_uuid + ) + + with open(args.dest_file, "w") as file: + file.write(formatted_plist) + + print("Wrote plist file to ", args.dest_file) + print() + print("Contents of file:") + print(formatted_plist) diff --git a/onnxruntime/test/platform/threadpool_test.cc b/onnxruntime/test/platform/threadpool_test.cc index 9b3eac1088a47..e0e6c0603c784 100644 --- a/onnxruntime/test/platform/threadpool_test.cc +++ b/onnxruntime/test/platform/threadpool_test.cc @@ -3,7 +3,7 @@ #include "core/platform/threadpool.h" #include "core/platform/EigenNonBlockingThreadPool.h" -#include "core/platform/ort_mutex.h" +#include #include "core/util/thread_utils.h" #ifdef _WIN32 #include "test/platform/windows/env.h" @@ -27,7 +27,7 @@ struct TestData { explicit TestData(int num) : data(num, 0) { } std::vector data; - onnxruntime::OrtMutex mutex; + std::mutex mutex; }; // This unittest tests ThreadPool function by counting the number of calls to function with each index. @@ -38,7 +38,7 @@ std::unique_ptr CreateTestData(int num) { } void IncrementElement(TestData& test_data, ptrdiff_t i) { - std::lock_guard lock(test_data.mutex); + std::lock_guard lock(test_data.mutex); test_data.data[i]++; } diff --git a/onnxruntime/test/platform/windows/stacktrace_test.cc b/onnxruntime/test/platform/windows/stacktrace_test.cc index de09dbcf270a9..9b1840f4b5d65 100644 --- a/onnxruntime/test/platform/windows/stacktrace_test.cc +++ b/onnxruntime/test/platform/windows/stacktrace_test.cc @@ -14,7 +14,6 @@ namespace onnxruntime { namespace test { using namespace ::testing; -// TVM is not working with StackTrace now. #if !defined(ORT_NO_EXCEPTIONS) TEST(StacktraceTests, BasicTests) { auto result = ::onnxruntime::GetStackTrace(); diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index dea39bc99d3e9..aa68f68f3e735 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -420,6 +420,7 @@ bool SetEpsForAllNodes(Graph& graph, continue; bool found = false; + const auto& logger = DefaultLoggingManager().DefaultLogger(); for (const auto& ep : execution_providers) { auto provider_type = ep->Type(); @@ -438,7 +439,8 @@ bool SetEpsForAllNodes(Graph& graph, } // Check the EP has an impl for the node from builtin registry. 
- if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver)) { + if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver, + logger)) { found = true; break; } @@ -451,6 +453,7 @@ bool SetEpsForAllNodes(Graph& graph, std::string_view(kMSInternalNHWCDomain), node.SinceVersion(), type_constraint_map, + logger, &kci); if (status.IsOK() && kci != nullptr) { found = true; @@ -463,7 +466,7 @@ bool SetEpsForAllNodes(Graph& graph, std::any_of(custom_registries->cbegin(), custom_registries->cend(), [&](auto reg) { return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(), - kernel_type_str_resolver); + kernel_type_str_resolver, logger); })) { found = true; break; diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index daa24db134114..a8480e7416de5 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -4,7 +4,7 @@ #include "core/common/logging/logging.h" #include "core/graph/graph.h" #include "core/graph/graph_viewer.h" -#include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory_creator.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" @@ -30,11 +30,11 @@ using namespace ::onnxruntime::logging; namespace onnxruntime { namespace test { -// We want to run UT on CPU only to get output value without losing precision to pass the verification -static constexpr uint32_t s_coreml_flags = COREML_FLAG_USE_CPU_ONLY; - -static std::unique_ptr MakeCoreMLExecutionProvider(uint32_t flags = s_coreml_flags) { - return std::make_unique(flags); +static std::unique_ptr MakeCoreMLExecutionProvider( + std::string ModelFormat = "NeuralNetwork", std::string ComputeUnits = "CPUOnly") { + std::unordered_map provider_options = {{kCoremlProviderOption_MLComputeUnits, ComputeUnits}, + {kCoremlProviderOption_ModelFormat, ModelFormat}}; + return CoreMLProviderFactoryCreator::Create(provider_options)->CreateProvider(); } #if !defined(ORT_MINIMAL_BUILD) @@ -127,6 +127,10 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { MakeCoreMLExecutionProvider(), feeds, verification_params); + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); #endif @@ -164,6 +168,11 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { MakeCoreMLExecutionProvider(), feeds, verification_params); + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some); #endif diff --git a/onnxruntime/test/providers/coreml/dynamic_input_test.cc b/onnxruntime/test/providers/coreml/dynamic_input_test.cc index c91ef23650040..8294f65745256 100644 --- a/onnxruntime/test/providers/coreml/dynamic_input_test.cc +++ b/onnxruntime/test/providers/coreml/dynamic_input_test.cc @@ -7,6 +7,7 @@ #include #include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory_creator.h" #include 
"core/providers/coreml/coreml_provider_factory.h" // for COREMLFlags #include "test/common/random_generator.h" #include "test/providers/model_tester.h" @@ -20,8 +21,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MatMul) { auto test = [&](const size_t M) { SCOPED_TRACE(MakeString("M=", M)); - - auto coreml_ep = std::make_unique(0); + std::unordered_map options; + auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); const auto ep_verification_params = EPVerificationParams{ ExpectedEPNodeAssignment::All, @@ -54,8 +55,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MobileNetExcerpt) { auto test = [&](const size_t batch_size) { SCOPED_TRACE(MakeString("batch_size=", batch_size)); - - auto coreml_ep = std::make_unique(0); + std::unordered_map options; + auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); const auto ep_verification_params = EPVerificationParams{ ExpectedEPNodeAssignment::All, @@ -87,6 +88,7 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) { constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx"); ModelTester tester(CurrentTestName(), model_path); + std::unordered_map options; tester.AddInput("A", {0, 2}, {}); tester.AddOutput("Y", {0, 4}, {}); @@ -94,14 +96,15 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) { tester .Config(ModelTester::ExpectResult::kExpectFailure, "the runtime shape ({0,2}) has zero elements. This is not supported by the CoreML EP.") - .ConfigEp(std::make_unique(0)) + .ConfigEp(CoreMLProviderFactoryCreator::Create(options)->CreateProvider()) .RunWithConfig(); } TEST(CoreMLExecutionProviderDynamicInputShapeTest, OnlyAllowStaticInputShapes) { constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx"); - - auto coreml_ep = std::make_unique(COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES); + std::unordered_map options = {{kCoremlProviderOption_RequireStaticInputShapes, "1"}}; + auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); + ; TestModelLoad(model_path, std::move(coreml_ep), // expect no supported nodes because we disable dynamic input shape support diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index 8ca0f6d845a09..59813f433dc41 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -105,7 +105,12 @@ class ActivationOpTest : public ::testing::Test { std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution dist(low, high); +#ifdef COREML_ENABLE_MLPROGRAM + // please check onnxruntime/onnxruntime/core/providers/coreml/builders/helper.cc:81 + std::vector batch_size_list = {1, 2, 4, 9, 100}; +#else std::vector batch_size_list = {1, 2, 4, 9, 100000}; +#endif for (auto batch_size : batch_size_list) { std::vector vec(batch_size); for (size_t i = 0; i != batch_size; ++i) { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index b2e9034653746..a74517840097c 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -414,6 +414,28 @@ TEST(MathOpTest, Add_Broadcast_3x2_3x1) { #endif } +TEST(MathOpTest, Add_Broadcast_2x2x2_1x2x2) { + OpTester test("Add"); + + test.AddInput("A", {2, 2, 2}, + {101.0f, 
102.0f, + 103.0f, 104.0f, + + 201.0f, 202.0f, + 203.0f, 204.0f}); + test.AddInput("B", {1, 2, 2}, + {010.0f, 020.0f, + 030.0f, 040.0f}); + test.AddOutput("C", {2, 2, 2}, + {111.0f, 122.0f, + 133.0f, 144.0f, + + 211.0f, 222.0f, + 233.0f, 244.0f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(MathOpTest, Add_Broadcast_2x1x4_1x3x1) { OpTester test("Add"); @@ -2249,6 +2271,21 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } +TEST(MathOpTest, Max_12_MLFloat16_Scalar2) { + OpTester test("Max", 12); + test.AddInput("data_0", {1}, + MakeMLFloat16({-1.f})); + test.AddInput("data_1", {}, + MakeMLFloat16({2.f})); + test.AddInput("data_2", {1, 3}, + MakeMLFloat16({-2.f, -3.f, -4.f})); + test.AddInput("data_3", {1, 1, 3}, + MakeMLFloat16({-2.f, -3.f, -4.f})); + test.AddOutput("max", {1, 1, 3}, + MakeMLFloat16({2.f, 2.f, 2.f})); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent +} + TEST(MathOpTest, Max_13_Float16_MatrixVector) { TestFloat16MinMax("Max", {4, 3}, @@ -3181,7 +3218,14 @@ TEST(MathOpTest, Tan) { TEST(MathOpTest, Asin) { OpTester test("Asin"); - float abs_error = DefaultDmlExecutionProvider().get() != nullptr ? 0.0001f : -1.0f; + float abs_error = +#ifdef _WIN32 + // Set abs_error to 0.0001f for built-in function asin() in HLSL based EPs (DML and WebGPU) + DefaultDmlExecutionProvider().get() != nullptr || DefaultWebGpuExecutionProvider().get() != nullptr + ? 0.0001f + : +#endif + -1.0f; TrigFloatTest<::asinf>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}, abs_error); } diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 66408e6adfbc5..d0069a0069646 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -25,7 +25,7 @@ const constexpr auto run_with_tunable_op = &run_options; } // namespace -// Only CUDA, ROCM and CoreML kernels have float 16 support +// Only CUDA, ROCM, CoreML and XNNPack kernels have float 16 support TEST(GemmOpTest, GemmNoTrans_f16) { #ifdef USE_CUDA int min_cuda_architecture = 530; diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index a7d2281ac19f8..298e870f348fc 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "gtest/gtest.h" + #include "test/providers/provider_test_utils.h" #include "test/providers/run_options_config_keys.h" #include "test/common/dnnl_op_test_utils.h" @@ -37,109 +38,125 @@ template std::vector> GenerateTestCases() { std::vector> test_cases; + auto real_expected_vals = [](const std::vector& expected_vals) { + if constexpr (std::is_same_v) { + return expected_vals; + } else if constexpr (std::is_same_v) { + std::vector expected_vals_fp16(expected_vals.size()); + std::transform(expected_vals.begin(), expected_vals.end(), expected_vals_fp16.begin(), + [](int32_t num) { return MLFloat16(float(num)); }); + return expected_vals_fp16; + } else { + std::vector real_expected_vals(expected_vals.size()); + std::transform(expected_vals.begin(), expected_vals.end(), real_expected_vals.begin(), + [](int32_t num) { return static_cast(num); }); + return real_expected_vals; + } + }; + test_cases.push_back( {"test padding and broadcast A > B", {3, 1, 1, 2}, {2, 2, 2}, {3, 2, 1, 2}, - {2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}}); + real_expected_vals({2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55})}); test_cases.push_back( {"test padding and broadcast B > A", {2, 3, 2}, {3, 2, 2, 1}, {3, 2, 3, 1}, - {1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}}); + real_expected_vals({1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221})}); test_cases.push_back( {"test left 1D", {2}, {3, 2, 1}, {3, 1}, - {1, 3, 5}}); + real_expected_vals({1, 3, 5})}); test_cases.push_back( {"test right 1D", {3, 1, 2}, {2}, {3, 1}, - {1, 3, 5}}); + real_expected_vals({1, 3, 5})}); test_cases.push_back( {"test left 1D right 2D", {2}, {2, 3}, {3}, - {3, 4, 5}}); + real_expected_vals({3, 4, 5})}); test_cases.push_back( {"test scalar output", {3}, {3}, {}, - {5}}); + real_expected_vals({5})}); test_cases.push_back( {"test 2D", {3, 4}, {4, 3}, {3, 3}, - {42, 48, 54, 114, 136, 158, 186, 224, 262}}); + real_expected_vals({42, 48, 54, 114, 136, 158, 186, 224, 262})}); test_cases.push_back( {"test 2D special", {2, 2, 3}, {3, 4}, {2, 2, 4}, - {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}); + real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})}); test_cases.push_back( {"test 2D special 2", {2, 2, 3}, {1, 3, 4}, {2, 2, 4}, - {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}); + real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})}); test_cases.push_back( {"test 2D special 3", {2, 6}, {1, 1, 6, 1}, {1, 1, 2, 1}, - {55, 145}}); + real_expected_vals({55, 145})}); test_cases.push_back( {"test 2D empty input", {3, 4}, {4, 0}, {3, 0}, - {}}); + real_expected_vals({})}); test_cases.push_back( {"test 3D batch", {3, 1, 3}, {3, 3, 2}, {3, 1, 2}, - { + real_expected_vals({ // clang-format off 10, 13, 100, 112, 298, 319, // clang-format on - }}); + })}); test_cases.push_back( {"test 4D batch", {2, 2, 1, 3}, {2, 2, 3, 2}, {2, 2, 1, 2}, - { + real_expected_vals({ // clang-format off 10, 13, 100, 112, 298, 319, 604, 634, // clang-format on - }}); + })}); return test_cases; } @@ -189,18 +206,31 @@ TEST(MathOpTest, MatMulFloatType) { GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; } RunMatMulTest(7, false, false); + // Note. 
Xnnpack only supports matmul when Matrix B is constant + RunMatMulTest(7, false, true); } -TEST(MathOpTest, MatMulDoubleType) { - RunMatMulTest(7); -} - -TEST(MathOpTest, MatMulFloatTypeInitializer) { +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK) +TEST(MathOpTest, MatMulFloat16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; + } +#endif // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; } - RunMatMulTest(7, false, true); + RunMatMulTest(14, false, false); + // Note. Xnnpack only supports matmul when Matrix B is constant + RunMatMulTest(14, false, true); +} +#endif + +TEST(MathOpTest, MatMulDoubleType) { + RunMatMulTest(7); } TEST(MathOpTest, MatMulInt32Type) { @@ -246,7 +276,7 @@ TEST(MathOpTest, MatMulZeroKInt32Type) { RunMatMulZeroKTest(); } -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK) TEST(MathOpTest, MatMul_Float16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -255,8 +285,6 @@ TEST(MathOpTest, MatMul_Float16) { return; } #endif - OpTester test("MatMul", 14); - std::vector A{1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; std::vector B(12, 1.0f); @@ -270,12 +298,18 @@ TEST(MathOpTest, MatMul_Float16) { ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); - test.AddInput("A", {2, 4}, f_A); - test.AddInput("B", {4, 3}, f_B); - test.AddOutput("Y", {2, 3}, f_Y); - test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported - .Config(run_with_tunable_op) - .RunWithConfig(); + auto run_test = [&](bool B_is_constant) { + // it needs Matrix B as constant to test XNNPack + OpTester test("MatMul", 14); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, B_is_constant); + test.AddOutput("Y", {2, 3}, f_Y); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + }; + run_test(true); + run_test(false); } #endif diff --git a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc index 8cdb837712e83..096263792727a 100644 --- a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc @@ -126,8 +126,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul3D_S8S8) { } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) { - auto run_test = [](bool only_t1_not_initializer) { - OpTester test("QLinearMatMul", 10); + auto run_test = [](bool only_t1_not_initializer, int opset_version) { + OpTester test("QLinearMatMul", opset_version); test.AddInput("T1", {2, 4}, {208, 236, 0, 238, 3, 214, 255, 29}); @@ -155,10 +155,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); }; - run_test(false); + run_test(false, 10); + run_test(false, 21); // NNAPI will require all inputs except T1 to be initializers - run_test(true); + run_test(true, 10); 
+ run_test(true, 21); } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) { @@ -197,8 +199,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) { } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) { - auto run_test = [](bool only_t1_not_initializer) { - OpTester test("QLinearMatMul", 10); + auto run_test = [](bool only_t1_not_initializer, int opset_version) { + OpTester test("QLinearMatMul", opset_version); test.AddInput("T1", {2, 4}, {80, -2, -128, 110, -125, 86, 127, -99}); @@ -225,10 +227,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) { test.Run(); }; - run_test(false); + run_test(false, 10); + run_test(false, 21); // NNAPI will require all inputs except T1 to be initializers - run_test(true); + run_test(true, 10); + run_test(true, 21); } static void QLinearMatMul2DTest(bool only_t1_not_initializer) { diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 6eb72255bdf9a..6f7930f722564 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -49,7 +49,7 @@ TEST(SoftmaxOperator, Simple) { RunTest(x_vals, expected_vals, dimensions); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_XNNPACK) TEST(SoftmaxOperator, Simple_fp16) { #ifdef USE_CUDA int min_cuda_architecture = 530; diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc new file mode 100644 index 0000000000000..49bb0ae65d1c9 --- /dev/null +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + array_as_tensor.add_dims(array.size()); + for (auto v : array) { + array_as_tensor.add_double_data(v); + } + + return array_as_tensor; +} + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + array_as_tensor.add_dims(array.size()); + for (auto v : array) { + array_as_tensor.add_float_data(v); + } + + return array_as_tensor; +} + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + array_as_tensor.add_dims(array.size()); + for (const auto v : array) { + array_as_tensor.add_int32_data(v); + } + + return array_as_tensor; +} + +template +void _multiply_update_array(std::vector& data, int n, T inc = 0) { + std::vector copy = data; + data.resize(copy.size() * n); + T cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + data[j + i * copy.size()] = copy[j] + cst; + } + cst += inc; + } +} + +template +void _multiply_update_childnode(std::vector& childnodes, std::vector& childleafs, std::vector& otherchildleafs, int n) { + int64_t leafs_cnt = 0; + int64_t nodes_cnt = childnodes.size(); + for (auto& childleaf : childleafs) { + if (childleaf) { + leafs_cnt++; + } + } + for (auto& childleaf : otherchildleafs) { + if (childleaf) { + leafs_cnt++; + } + } + + std::vector copy = childnodes; + childnodes.resize(copy.size() * n); + T leafs_cst = 0; + T nodes_cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + T curr_inc = childleafs[j] ? leafs_cst : nodes_cst; + childnodes[j + i * copy.size()] = copy[j] + curr_inc; + } + + leafs_cst += leafs_cnt; + nodes_cst += nodes_cnt; + } +} + +template +void _multiply_arrays_values(std::vector& data, int64_t val) { + for (auto& curr : data) { + curr *= val; + } +} + +template +void GenTreeAndRunTest(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 2; + + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_featureids = {0, 0, 0}; + std::vector nodes_modes = {0, 0, 0}; + std::vector nodes_splits = {3.14f, 1.2f, 4.2f}; + std::vector nodes_truenodeids = {1, 0, 1}; + std::vector nodes_trueleafs = {0, 1, 1}; + std::vector nodes_falsenodeids = {2, 2, 3}; + std::vector nodes_falseleafs = {0, 1, 1}; + + std::vector leaf_targetids = {0, 1, 0, 1}; + std::vector leaf_weights = {5.23f, 12.12f, -12.23f, 7.21f}; + + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. 
+ _multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size()); + _multiply_update_array(nodes_featureids, n_trees); + _multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees); + _multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees); + _multiply_update_array(nodes_trueleafs, n_trees); + _multiply_update_array(nodes_falseleafs, n_trees); + _multiply_update_array(leaf_targetids, n_trees); + _multiply_update_array(nodes_modes, n_trees); + _multiply_update_array(nodes_splits, n_trees); + _multiply_update_array(leaf_weights, n_trees); + } + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 2}, Y); + test.Run(); +} + +template +void GenTreeAndRunTestWithSetMembership(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 4; + + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_featureids = {0, 0, 0}; + std::vector nodes_truenodeids = {1, 0, 1}; + std::vector nodes_trueleafs = {0, 1, 1}; + std::vector nodes_falsenodeids = {2, 2, 3}; + std::vector nodes_falseleafs = {1, 0, 1}; + std::vector leaf_targetids = {0, 1, 2, 3}; + + std::vector nodes_modes = {0, 6, 6}; + std::vector nodes_splits = {11.f, 232344.f, NAN}; + std::vector membership_values = {1.2f, 3.7f, 8.f, 9.f, NAN, 12.f, 7.f, NAN}; + std::vector leaf_weights = {1.f, 10.f, 1000.f, 100.f}; + + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. 
+ _multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size()); + _multiply_update_array(nodes_featureids, n_trees); + _multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees); + _multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees); + _multiply_update_array(nodes_trueleafs, n_trees); + _multiply_update_array(nodes_falseleafs, n_trees); + _multiply_update_array(leaf_targetids, n_trees); + _multiply_update_array(nodes_modes, n_trees); + _multiply_update_array(nodes_splits, n_trees); + _multiply_update_array(membership_values, n_trees); + _multiply_update_array(leaf_weights, n_trees); + } + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto membership_values_as_tensor = make_tensor(membership_values, "membership_values"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("membership_values", membership_values_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + test.AddInput("X", {6, 1}, X); + test.AddOutput("Y", {6, 4}, Y); + test.Run(); +} + +TEST(MLOpTest, TreeEnsembleFloat) { + std::vector X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f}; + std::vector Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f}; + GenTreeAndRunTest(X, Y, 1, 1); + + Y = {15.69f, 0.f, 15.69f, 0.f, 0.f, 36.36f}; + GenTreeAndRunTest(X, Y, 1, 3); +} + +TEST(MLOpTest, TreeEnsembleDouble) { + std::vector X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f}; + std::vector Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f}; + GenTreeAndRunTest(X, Y, 1, 1); + + _multiply_arrays_values(Y, 3); + GenTreeAndRunTest(X, Y, 1, 3); +} + +TEST(MLOpTest, TreeEnsembleSetMembership) { + std::vector X = {1.2f, 3.4f, -0.12f, NAN, 12.0f, 7.0f}; + std::vector Y = { + 1.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 100.f, + 0.f, 0.f, 0.f, 100.f, + 0.f, 0.f, 1000.f, 0.f, + 0.f, 0.f, 1000.f, 0.f, + 0.f, 10.f, 0.f, 0.f}; + GenTreeAndRunTestWithSetMembership(X, Y, 1, 1); + + _multiply_arrays_values(Y, 5); + GenTreeAndRunTestWithSetMembership(X, Y, 1, 5); +} + +TEST(MLOpTest, TreeEnsembleLeafOnly) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 1; + + int64_t aggregate_function = 1; + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_modes = {0}; + std::vector nodes_featureids = {0}; + std::vector nodes_splits = {0.f}; + std::vector nodes_truenodeids = {0}; + std::vector nodes_trueleafs = {1}; + std::vector nodes_falsenodeids = {0}; + std::vector nodes_falseleafs = {1}; + + std::vector leaf_targetids = {0}; + std::vector leaf_weights = {6.23f}; + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = 
make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + std::vector X = {1.f, 4.f}; + std::vector Y = {6.23f, 6.23f}; + + test.AddInput("X", {2, 1}, X); + test.AddOutput("Y", {2, 1}, Y); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 33c23b53fb5aa..eaf8fea03eaa0 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) { GenTreeAndRunTest1_as_tensor_precision(3); } +TEST(MLOpTest, TreeRegressorCategoricals) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 0, 0, 1, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 4, 0, 5.5, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0}; + std::vector nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, 0}; + std::vector target_nodeids = {3, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f}; + std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 1}, Y); + test.Run(); +} + +TEST(MLOpTest, TreeRegressorCategoricalsFolding) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 1, 1, 0, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", 
"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 2, 3, 0, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0}; + std::vector nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, 0}; + std::vector target_nodeids = {4, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f}; + std::vector Y = {11.100000381469727, 11.100000381469727, -4.699999809265137, 17.700000762939453}; + test.AddInput("X", {4, 2}, X); + test.AddOutput("Y", {4, 1}, Y); + test.Run(); +} + TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) { OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 177647ab5be6b..e3c86a137484f 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -570,7 +570,7 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("yolov3"), ORT_TSTR("LSTM_Seq_lens_unpacked"), ORT_TSTR("tinyyolov3"), - ORT_TSTR("faster_rcnn"), + // ORT_TSTR("faster_rcnn"), ORT_TSTR("mask_rcnn"), ORT_TSTR("coreml_FNS-Candy_ImageNet"), ORT_TSTR("tf_mobilenet_v2_1.0_224"), @@ -581,7 +581,7 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("mlperf_ssd_resnet34_1200"), ORT_TSTR("candy"), ORT_TSTR("cntk_simple_seg"), - ORT_TSTR("GPT2_LM_HEAD"), + // ORT_TSTR("GPT2_LM_HEAD"), ORT_TSTR("mlperf_ssd_mobilenet_300"), ORT_TSTR("fp16_coreml_FNS-Candy"), ORT_TSTR("fp16_test_tiny_yolov2"), diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index b0d97410ac9b3..08c4e608aada3 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -704,7 +704,7 @@ TEST(BatchNormTest, NonSpatial_Complicated) { } // Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(BatchNormTest, BatchNorm2d_fp16) { vector X{-0.91221f, -0.283559f, 0.937637f, 2.09818f, -0.100199f, -0.608113f, 0.444562f, -1.07505f, 0.940591f, -0.922262f, 0.0931303f, 0.69611f, 1.55187f, 0.159808f, 0.914874f, -1.24856f, -1.98928f, -0.331621f, @@ -765,9 +765,6 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { -0.0989828f, -0.160014f, 0.362077f, 0.0649763f, -0.371465f, 0.727401f, 0.0320011f}; float epsilon = 1e-05f; - OpTester test("BatchNormalization"); - 
test.AddAttribute("epsilon", epsilon); - vector input_shape{2, 3, 6, 6}; int input_size = 2 * 3 * 6 * 6; @@ -785,13 +782,20 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { ConvertFloatToMLFloat16(var.data(), f_var.data(), 3); ConvertFloatToMLFloat16(expected_output.data(), f_output.data(), input_size); - test.AddInput("X", input_shape, f_X); - test.AddInput("scale", {3}, f_scale); - test.AddInput("B", {3}, f_B); - test.AddInput("mean", {3}, f_mean); - test.AddInput("var", {3}, f_var); - test.AddOutput("output", input_shape, f_output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + auto run_test = [&](bool is_initializer) { + OpTester test("BatchNormalization"); + test.AddAttribute("epsilon", epsilon); + test.AddInput("X", input_shape, f_X); + test.AddInput("scale", {3}, f_scale, is_initializer); + test.AddInput("B", {3}, f_B, is_initializer); + test.AddInput("mean", {3}, f_mean, is_initializer); + test.AddInput("var", {3}, f_var, is_initializer); + test.AddOutput("output", input_shape, f_output, is_initializer); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + // coreml EP requires initializer + run_test(true); } #endif diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index ce1ac7591ec34..4253e36e02548 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 25caa732efa25..a3a3dd939cbf0 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - +#include "core/graph/constants.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" + using namespace std; namespace onnxruntime { namespace test { @@ -28,7 +29,8 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, optional epsilon = optional(), OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - int opset = 7) { + int opset = 7, + bool exclude_cuda_nhwc = false) { OpTester test("Conv", opset); test.AddAttribute("group", attributes.group); test.AddAttribute("kernel_shape", attributes.kernel_shape); @@ -65,6 +67,12 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); + if (exclude_cuda_nhwc) { +#ifdef ENABLE_CUDA_NHWC_OPS + excluded_providers.insert(kCudaNHWCExecutionProvider); +#endif + } + // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs. excluded_providers.insert(kQnnExecutionProvider); @@ -197,10 +205,15 @@ TEST(ConvTest, Conv1D_Bias) { // as TF32 has a 10 bit mantissa. 
float epsilon = 1.1e-5f; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon); + // This case is not supported by cuDNN frontend, and the fallback (legacy code) requires weight to 4D tensor for NHWC. + constexpr bool exclude_cuda_nhwc = true; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon, + OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc); // CoreML EP requires weight to be an initializer - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon, + OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc); } // Conv47 diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 29525f89ef544..83b27f10fe04f 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -1,8 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/xnnpack/xnnpack_init.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" #include "default_providers.h" using namespace std; @@ -28,6 +30,8 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const vector>& input_shapes, const std::vector& expected_output, const vector& expected_output_shape, + float rel_error = 0.0, + float abs_error = 0.0, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", @@ -64,7 +68,7 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, for (size_t i = 0; i < inputs.size(); i++) { test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } - test.AddOutput("Y", expected_output_shape, expected_output); + test.AddOutput("Y", expected_output_shape, expected_output, false, rel_error, abs_error); test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported } @@ -78,12 +82,16 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", const std::unordered_set& excluded_provider_types = - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) { + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, + float rel_error = 0.0, + float abs_error = 0.0) { std::unordered_set extra_exclude_openvino_for_initializer_filter = excluded_provider_types; extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, true, expect_result, err_str, extra_exclude_openvino_for_initializer_filter); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, false, expect_result, err_str, excluded_provider_types); } @@ -123,17 +131,6 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } -template -static 
std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
-  if constexpr (std::is_same<T, float>::value) {
-    return inputs;
-  } else {
-    std::vector<MLFloat16> inputs_fp16(inputs.size());
-    ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
-    return inputs_fp16;
-  }
-}
-
 TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) {
   ConvTransposeOpAttributes attrs = {
       vector<int64_t>{3, 3},  // kernel_shape
@@ -245,8 +242,22 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) {
                                  0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f,
                                  0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f,
                                  0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f};
+#ifdef XNNPACK_FP16_SUPPORTED
+  if constexpr (std::is_same<TypeParam, MLFloat16>::value) {
+    TestConvTransposeOp(attrs, {GetTypedArray<TypeParam>(X), GetTypedArray<TypeParam>(W), GetTypedArray<TypeParam>(B)},
+                        {X_shape, W_shape, B_shape}, GetTypedArray<TypeParam>(expected_vals), Y_shape,
+                        OpTester::ExpectResult::kExpectSuccess, "",  // default value
+                        {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider},  // default value
+                        0.5, 0.5);
+  } else {
+    TestConvTransposeOp(attrs, {GetTypedArray<TypeParam>(X), GetTypedArray<TypeParam>(W), GetTypedArray<TypeParam>(B)},
+                        {X_shape, W_shape, B_shape}, GetTypedArray<TypeParam>(expected_vals), Y_shape);
+  }
+
+#else
   TestConvTransposeOp(attrs, {GetTypedArray<TypeParam>(X), GetTypedArray<TypeParam>(W), GetTypedArray<TypeParam>(B)},
                       {X_shape, W_shape, B_shape}, GetTypedArray<TypeParam>(expected_vals), Y_shape);
+#endif
 }
 
 TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) {
diff --git a/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc
new file mode 100644
index 0000000000000..ac517193a2c77
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc
@@ -0,0 +1,144 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/util/include/default_providers.h" + +#ifdef COREML_ENABLE_MLPROGRAM +using namespace std; +namespace onnxruntime { +namespace test { + +template +class GroupNormalizationOpTest : public ::testing::Test { +}; +using GroupNormalizationOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GroupNormalizationOpTest, GroupNormalizationOpTestTypes); + +// GroupSize = channel_dims to simulate InstanceNorm +// Disable TensorRT on some of the tests because its parser doesn't support weight as input +TYPED_TEST(GroupNormalizationOpTest, Equivalent_InstanceNorm_G_C) { + OpTester test("GroupNormalization", 18); + test.AddAttribute("epsilon", 0.3F); + test.AddAttribute("num_groups", int64_t(3)); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, + + 2.3667817F, 2.8248506F, 3.7754705F, 5.861325F, + 5.058735F, 3.2787242F, 3.6843839F, 9.755121F, + 2.7902672F, 7.3974323F, 8.283609F, 8.488337F}; + vector input_dims = {2, 3, 4}; + test.AddInput("X", input_dims, GetTypedArray(input)); + + vector scale = {1.F, 1.F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), true); + + vector B = {0.F, 0.F, 0.F}; + vector B_dims = {3}; + test.AddInput("bias", B_dims, GetTypedArray(B), true); + + // expected output is calculated using torch.nn.GroupNorm(3, 3, eps=0.3) + vector expected_output = {-0.56495477f, 1.48930046f, -1.13334329f, 0.20899761f, + 1.46688162f, -0.98600774f, -0.79911913f, 0.31824524f, + 0.57370438f, 0.42193634f, 0.6525492f, -1.64818992f, + + -0.92380346f, -0.60808484f, 0.04711878f, 1.48476953f, + -0.14644464f, -0.82262872f, -0.66852817f, 1.63760153f, + -1.65898662f, 0.27618144f, 0.64840618f, 0.734399f}; + + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// GroupSize = 1 to simulate LayerNorm, (LayerNorm) +// expected output is calculated using torch.nn.GroupNorm(1, 3, eps=1e-5f) +TYPED_TEST(GroupNormalizationOpTest, Equivalent_LayerNorm_G_1) { + auto run_test = [](bool is_initializer) { + OpTester test("GroupNormalization", 18); + test.AddAttribute("epsilon", 1e-5f); + test.AddAttribute("num_groups", int64_t(1)); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, GetTypedArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("scale", {2}, GetTypedArray({1.0f, 1.0f}), is_initializer); + test.AddInput("bias", {2}, GetTypedArray({2.0f, 1.0f}), is_initializer); + test.AddOutput("output", dims, GetTypedArray({0.5361f, 1.1216f, 1.7072f, 1.2928f, 1.8783f, 2.4638f})); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + }; + + run_test(true); +} + +// expected output is calculated using torch.nn.GroupNorm(2, 6, eps=0.3) +TYPED_TEST(GroupNormalizationOpTest, GroupSize_N) { + OpTester test("GroupNormalization", 18); + test.AddAttribute("epsilon", 0.3F); + test.AddAttribute("num_groups", int64_t(2)); + + vector input = {-1.1258f, 
-1.1524f, -0.2506f, -0.4339f, + 0.8487f, 0.6920f, -0.3160f, -2.1152f, + 0.3223f, -1.2633f, 0.3500f, 0.3081f, + 0.1198f, 1.2377f, 1.1168f, -0.2473f, + -1.3527f, -1.6959f, 0.5667f, 0.7935f, + 0.5988f, -1.5551f, -0.3414f, 1.8530f, + + 0.7502f, -0.5855f, -0.1734f, 0.1835f, + 1.3894f, 1.5863f, 0.9463f, -0.8437f, + -0.6136f, 0.0316f, -0.4927f, 0.2484f, + 0.4397f, 0.1124f, 0.6408f, 0.4412f, + -0.1023f, 0.7924f, -0.2897f, 0.0525f, + 0.5229f, 2.3022f, -1.4689f, -1.5867f}; + vector input_dims = {2, 6, 4}; + test.AddInput("X", input_dims, GetTypedArray(input)); + + vector scale = {1.F, 1.F, 1.F, 1.F, 1.F, 1.F}; + vector scale_dims = {6}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), true); + + vector B = {.0F, .0F, .0F, .0F, .0F, .0F}; + vector B_dims = {6}; + test.AddInput("bias", B_dims, GetTypedArray(B), true); + + vector expected_output = { + -0.7590f, -0.7848f, 0.0914f, -0.0867f, + 1.1595f, 1.0073f, 0.0278f, -1.7203f, + 0.6480f, -0.8926f, 0.6749f, 0.6343f, + 0.0232f, 0.9274f, 0.8296f, -0.2738f, + -1.1679f, -1.4456f, 0.3846f, 0.5681f, + 0.4107f, -1.3317f, -0.3499f, 1.4252f, + + 0.5772f, -0.8298f, -0.3957f, -0.0198f, + 1.2505f, 1.4580f, 0.7838f, -1.1017f, + -0.8594f, -0.1798f, -0.7320f, 0.0486f, + 0.2541f, -0.0377f, 0.4334f, 0.2554f, + -0.2291f, 0.5686f, -0.3962f, -0.0911f, + 0.3282f, 1.9145f, -1.4475f, -1.5525f}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + if constexpr (std::is_same::value) { + test.SetOutputTolerance(1e-4f); + } else { + test.SetOutputTolerance(0.005f); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace test +} // namespace onnxruntime +#endif diff --git a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc index 31f119ec6b0e9..341bb8a4fc957 100644 --- a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc @@ -3,71 +3,87 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" + using namespace std; namespace onnxruntime { namespace test { -// Disable TensorRT on some of the tests because its parser doesn't support weight as input +template +class InstanceNormalizationOpTest : public ::testing::Test { +}; +using InstanceNormalizationOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(InstanceNormalizationOpTest, InstanceNormalizationOpTestTypes); -TEST(InstanceNormalizationOpTest, InstanceNorm) { - OpTester test("InstanceNormalization"); - test.AddAttribute("epsilon", 0.3F); - - vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, - 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, - 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, - - 2.3667817F, 2.8248506F, 3.7754705F, 5.861325F, - 5.058735F, 3.2787242F, 3.6843839F, 9.755121F, - 2.7902672F, 7.3974323F, 8.283609F, 8.488337F}; - vector input_dims = {2, 3, 4}; - test.AddInput("input", input_dims, input); - - // vector scale = {2.1F, 0.1F, 1.F}; - vector scale = {1.0F, 1.0F, 1.F}; - vector scale_dims = {3}; - test.AddInput("scale", scale_dims, scale); - - // vector B = {2.3F, 1.5F, 0.F}; - vector B = {0.0F, 0.0F, 0.F}; - vector B_dims = {3}; - test.AddInput("B", B_dims, B); - - vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, - 
1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, - 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, +// Disable TensorRT on some of the tests because its parser doesn't support weight as input - -0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F, - -0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F, - -1.65898662F, 0.27618144F, 0.64840618F, 0.734399F}; - test.AddOutput("Y", input_dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +TYPED_TEST(InstanceNormalizationOpTest, InstanceNorm) { + auto run_test = [](bool is_initializer) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, + + 2.3667817F, 2.8248506F, 3.7754705F, 5.861325F, + 5.058735F, 3.2787242F, 3.6843839F, 9.755121F, + 2.7902672F, 7.3974323F, 8.283609F, 8.488337F}; + vector input_dims = {2, 3, 4}; + test.AddInput("input", input_dims, GetTypedArray(input)); + + // vector scale = {2.1F, 0.1F, 1.F}; + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), is_initializer); + + // vector B = {2.3F, 1.5F, 0.F}; + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + test.AddInput("B", B_dims, GetTypedArray(B), is_initializer); + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, + + -0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F, + -0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F, + -1.65898662F, 0.27618144F, 0.64840618F, 0.734399F}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + run_test(true); } -TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { - OpTester test("InstanceNormalization"); - test.AddAttribute("epsilon", 0.3F); - - vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, - 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, - 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; - vector input_dims = {1, 3, 4}; - test.AddInput("input", input_dims, input); - - vector scale = {1.0F, 1.0F, 1.F}; - vector scale_dims = {3}; - test.AddInput("scale", scale_dims, scale); - - vector B = {0.0F, 0.0F, 0.F}; - vector B_dims = {3}; - test.AddInput("B", B_dims, B); - - vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, - 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, - 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; - test.AddOutput("Y", input_dims, expected_output); - - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +TYPED_TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { + auto run_test = [](bool is_initializer) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; + vector input_dims = {1, 3, 4}; + test.AddInput("input", input_dims, GetTypedArray(input)); + + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), is_initializer); + + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + 
test.AddInput("B", B_dims, GetTypedArray(B), is_initializer); + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + run_test(true); } TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { @@ -105,7 +121,7 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { } // Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(InstanceNormalizationOpTest, InstanceNormBatch1_fp16) { OpTester test("InstanceNormalization"); diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index 46eb1180f4e7e..d4e0af5011525 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK) #include "core/providers/cpu/nn/pool.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 0968bc32e0de4..f68a245d103e1 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include "gtest/gtest.h" #include "test/common/dnnl_op_test_utils.h" @@ -1374,7 +1375,7 @@ TEST(ReductionOpTest, ReduceMax_double) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(ReductionOpTest, ReduceMax_half) { OpTester test("ReduceMax"); test.AddAttribute("axes", std::vector{1, 2}); @@ -2157,7 +2158,7 @@ TEST(ReductionOpTest, ReduceMin_double) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(ReductionOpTest, ReduceMin_half) { OpTester test("ReduceMin"); test.AddAttribute("axes", std::vector{0, 2}); @@ -2355,7 +2356,7 @@ TEST(ReductionOpTest, ReduceSum_int32) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(ReductionOpTest, ReduceSumHalfHalf) { OpTester test("ReduceSum"); test.AddAttribute("keepdims", (int64_t)0); @@ -3175,19 +3176,26 @@ TEST(ReductionOpTest, ReduceProd0DTensor) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -TEST(ReductionOpTest, ArgMax) { +template +class ReductionOpTest : public ::testing::Test { +}; + +using ReductionOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ReductionOpTest, ReductionOpTestTypes); + +TYPED_TEST(ReductionOpTest, ArgMax) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)1); test.AddAttribute("keepdims", (int64_t)1); - test.AddInput("data", {3, 2, 2}, - {1.0f, 2.0f, - 3.0f, 4.0f, + test.AddInput("data", {3, 2, 2}, + GetTypedArray({1.0f, 2.0f, + 3.0f, 4.0f, - 5.0f, 6.0f, - 7.0f, 
8.0f,
+                                                     5.0f, 6.0f,
+                                                     7.0f, 8.0f,
 
-                        9.0f, 10.0f,
-                        11.0f, 12.0f});
+                                                     9.0f, 10.0f,
+                                                     11.0f, 12.0f}));
 
   test.AddOutput<int64_t>("reduced", {3, 1, 2},
                           {1, 1,
                            1, 1,
@@ -3330,6 +3338,41 @@ TEST(ReductionOpTest, ArgMax_int32_last_index_dups) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
+TEST(ReductionOpTest, ArgMax_float_first_index_random) {
+  OpTester test("ArgMax", 12);
+  test.AddAttribute("axis", static_cast<int64_t>(0));
+  test.AddAttribute("keepdims", static_cast<int64_t>(1));
+
+  // Since select_last_index is 0 by default, this test should run on both CPU and CUDA
+  test.AddAttribute("select_last_index", static_cast<int64_t>(0));
+
+  constexpr size_t vector_size = 64 * 1024;
+  constexpr float max_value = std::numeric_limits<float>::infinity();
+
+  std::random_device rd;
+  std::mt19937 generator(rd());
+  std::uniform_int_distribution<int> distribution(0, static_cast<int>(vector_size) - 1);
+
+  std::vector<float> data_vec(vector_size, 0.0f);
+
+  int min_index = -1;
+
+  // Try to replace 8 elements with max_value. It is fine that some elements hit the same index.
+  for (int i = 0; i < 8; ++i) {
+    int index = distribution(generator);
+    data_vec[index] = max_value;
+    if (i == 0 || index < min_index) {
+      min_index = index;
+    }
+  }
+
+  test.AddInput<float>("data", {vector_size}, data_vec);
+  test.AddOutput<int64_t>("reduced", {1}, {min_index});
+
+  // Exclude OpenVINO since it failed to handle this case.
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
 TEST(ReductionOpTest, ArgMax_int32_neg_axis) {
   OpTester test("ArgMax");
   test.AddAttribute("axis", (int64_t)(-2));
@@ -3648,6 +3691,41 @@ TEST(ReductionOpTest, ArgMin_int32_neg_axis) {
   test.Run();
 }
 
+TEST(ReductionOpTest, ArgMin_float_first_index_random) {
+  OpTester test("ArgMin", 13);
+  test.AddAttribute("axis", static_cast<int64_t>(0));
+  test.AddAttribute("keepdims", static_cast<int64_t>(1));
+
+  // Since select_last_index is 0 by default, this test should run on both CPU and CUDA
+  test.AddAttribute("select_last_index", static_cast<int64_t>(0));
+
+  constexpr size_t vector_size = 64 * 1024;
+  constexpr float min_value = -std::numeric_limits<float>::infinity();
+
+  std::random_device rd;
+  std::mt19937 generator(rd());
+  std::uniform_int_distribution<int> distribution(0, static_cast<int>(vector_size) - 1);
+
+  std::vector<float> data_vec(vector_size, 0.0f);
+
+  int min_index = -1;
+
+  // Try to replace 8 elements with min_value. It is fine that some elements hit the same index.
+  for (int i = 0; i < 8; ++i) {
+    int index = distribution(generator);
+    data_vec[index] = min_value;
+    if (i == 0 || index < min_index) {
+      min_index = index;
+    }
+  }
+
+  test.AddInput<float>("data", {vector_size}, data_vec);
+  test.AddOutput<int64_t>("reduced", {1}, {min_index});
+
+  // Exclude OpenVINO since it failed to handle this case.
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1) { FastReduceKind fast_kind; TensorShapeVector fast_shape, fast_output_shape, fast_axes; @@ -5603,7 +5681,7 @@ TEST(ReductionOpTest, ReduceSum_RK_parallel) { test.AddOutput("reduced", {32}, expected); // CoreML does not provide 1e-5 precision here (it's off by 1e-4) - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCoreMLExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess); } TEST(ReductionOpTest, ReduceSum_RK_keepdims) { diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc index 30960e71c577f..de2aceda17f83 100644 --- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc @@ -13,7 +13,7 @@ using namespace std; namespace onnxruntime { namespace test { -static const std::vector default_activations = {"sigmoid", "tanh"}; +static const std::vector default_activations = {"Sigmoid", "Tanh"}; static void RunGruTest(const std::vector& X_data, const std::vector& W_data, @@ -150,11 +150,6 @@ void DefaultActivationsSimpleWeightsNoBias(std::string direction, } TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsNoBiasTwoRows) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.4750208f, 0.450166f, 0.4255575f, 0.45016602f, 0.40131235f, 0.35434368f, @@ -173,11 +168,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsNoBiasTwoRows) { } TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.6082785f, 0.50623393f, 0.4426924f, 0.5803454f, 0.4527356f, 0.36886263f, @@ -193,11 +183,6 @@ TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows) { } TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBias) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ // forward output for input sequence 0 0.4750208f, 0.450166f, 0.4255575f, @@ -228,11 +213,6 @@ TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBias) { } TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBiasLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ // forward output for input sequence 0 0.4750208f, 0.450166f, 0.4255575f, @@ -317,11 +297,6 @@ void DefaultActivationsSimpleWeightsWithBias(std::string direction, } TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallel) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - 
std::vector Y_data{ 0.16783132f, -0.11754231f, 0.11977843f, 0.2046872f, -0.10372487f, 0.15365849f, @@ -333,11 +308,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallel) { } TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.15024948f, -0.11097029f, -0.02121867f, 0.18887489f, -0.09747667f, 0.02093463f, @@ -350,11 +320,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearB } TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.20910699f, -0.18880953f, -0.04005555f, 0.29700265f, -0.15308119f, 0.04537245f, @@ -368,11 +333,6 @@ TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearB // test forward !batch_parallel_ path with linear_before_reset TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.15024948f, -0.11097029f, -0.02121867f, 0.19538902f, -0.19016478f, -0.05644283f}; @@ -384,11 +344,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset) { // test reverse !batch_parallel_ path with linear_before_reset TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.20910699f, -0.18880953f, -0.04005555f, 0.12252139f, -0.12032216f, -0.05064924f}; @@ -588,13 +543,8 @@ void DeepCpuGruOpTestContext::RunTest(const std::vector& X, } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBasic) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -611,13 +561,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBasic) { } TEST(GRUTest, ONNXRuntime_TestGRUOpBackwardBasic) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "reverse"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -635,13 +580,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpBackwardBasic) { } TEST(GRUTest, ONNXRuntime_TestGRUOpBidirectionalBasic) { - // TODO: Unskip when 
fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -663,13 +603,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpBidirectionalBasic) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardActivation) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"tanh", "sigmoid"}; + const std::vector activations = {"Tanh", "Sigmoid"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -687,13 +622,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardActivation) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardInitialHiddenState) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -711,13 +641,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardInitialHiddenState) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatch) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -743,13 +668,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatch) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatchLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -775,13 +695,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatchLinearBeforeReset) { } TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLength) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -820,13 +735,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLength) { } TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLengthLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { 
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -865,13 +775,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLengthLinearBeforeReset) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeResetB1) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -891,13 +796,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeResetB2) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -916,13 +816,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -949,13 +844,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe } TEST(GRUTest, ONNXRuntime_TestGRUOpShorterSeqInMiddle) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -987,13 +877,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpShorterSeqInMiddle) { } TEST(GRUTest, ONNXRuntime_TestGRUOpZeroSeqInMiddle) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, 
activations); @@ -1025,13 +910,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpZeroSeqInMiddle) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -1058,13 +938,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthShorterThanInputSequenceLength) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -1092,13 +967,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthShorterThanInputSequenceLength) } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthAllZeros) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -1125,13 +995,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthAllZeros) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSingleBatchMultipleHiddenThreads) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations, true, {}, {}, /*large_hidden*/ true); @@ -1160,13 +1025,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSingleBatchMultipleHiddenThreads) { } TEST(GRUTest, ONNXRuntime_TestGRUPositiveActivationClipping) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations, true, {}, {}, /*large_hidden*/ true); diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 421561a5a859b..384adb5916cc1 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -149,11 +149,6 @@ using CastNonStringTypes = MLFloat16, BFloat16>; TEST(CastOpTest, NonStringTypes) { - // TODO: Unskip when fixed #41968513 - if 
(DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: true and true"; - } - boost::mp11::mp_for_each>( CastNonStringTester{}); } diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 4a1888a5ca7d6..9e0fb81cbb0fc 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -75,17 +76,6 @@ TEST(ConcatOpTest, Concat1D_2) { kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } -template -static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - TYPED_TEST(ConcatOpTest, Concat2D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); diff --git a/onnxruntime/test/providers/cpu/tensor/expand_test.cc b/onnxruntime/test/providers/cpu/tensor/expand_test.cc index 4b0f4e84ca37d..38e3bc3af6d6b 100644 --- a/onnxruntime/test/providers/cpu/tensor/expand_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/expand_test.cc @@ -167,6 +167,20 @@ TEST(ExpandOpTest, Expand_2x2x1x2x1_float) { test.Run(); } +TEST(ExpandOpTest, Expand_3x1x8_float) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test.AddInput("data_1", {3}, {3, 1, 8}); + test.AddOutput("result", {3, 2, 8}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, + 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, + 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f}); + test.Run(); +} + #ifndef USE_TENSORRT TEST(ExpandOpTest, Expand_scalar_float) { OpTester test("Expand", 8); diff --git a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc index 5b2d00bb956bf..81e51375b9992 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc @@ -389,9 +389,10 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) { // skip openvino which will not throw error message but will ensure no out-of-bound access // skip TensorRT because it doesn't support out of bounds indices // skip QNN because it doesn't support out of bounds indices + // skip WebGPU because it doesn't support out of bounds indices test.Run(OpTester::ExpectResult::kExpectFailure, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider, - kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider}); + kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider}); } TEST(GatherElementsOpTest, BigIndices) { diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index a32d43f296250..2169436255727 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ 
b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -263,22 +264,6 @@ TEST(SliceTest, Slice3D) { 332.0f, 333.0f}); } -template -static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - std::vector inputs_T(inputs.size()); - if constexpr (std::is_same::value) { - return inputs; - } else if constexpr (std::is_integral_v) { - for (size_t i = 0; i < inputs.size(); i++) { - inputs_T[i] = static_cast(inputs[i]); - } - return inputs_T; - } else { - ConvertFloatToMLFloat16(inputs.data(), inputs_T.data(), inputs.size()); - return inputs_T; - } -} - template static void TestSlice1DIntData() { // static_assert(std::is_integral_v); diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 48872404f08bd..1c2a86bb808b5 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "core/framework/to_tensor_proto_element_type.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -178,17 +179,6 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) { RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); } -template -std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - TEST(SplitOperatorTest, Axis0UnequalSplitString) { constexpr int64_t axis = 0; std::vector outputs; diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index b517b1a2837f0..5902fbe3ddd6f 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -142,7 +142,7 @@ void RunTestWrapper() { RunTest({2, 1, 3}, {2, 2, 1}); RunTest({2, 1, 3}, {2, 2, 1}, true); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) // _TileMemcpyKernelFromInput, vectorized 4 RunTest({256, 512}, {3, 1}); @@ -253,7 +253,7 @@ TEST(TensorOpTest, TileStringType) { RunTestWrapper(); } TEST(TensorOpTest, TileBoolType) { RunTestWrapperForBool(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(TensorOpTest, TileMLFloat16Type) { RunTestWrapper(); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 3b46dc3f5d6a2..73a5bce768a2a 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -69,17 +69,6 @@ void TransposeTest(const std::vector& input_shape, } } -template -std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - // 
Test 2 dimensional transpose, with no permutation attribute specified TYPED_TEST(TransposeOpTest, TwoDimNoAttr) { std::vector input_shape({2, 3}); diff --git a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc index d2aa5dd428fec..d1910c89f76b7 100644 --- a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc @@ -11,7 +11,7 @@ namespace test { // Disable TensorRT on the tests because of SegFault errors in the parser -TEST(TensorOpTest, Unsqueeze_1) { +TEST(UnsqueezeOpTest, Unsqueeze_1) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{1}); @@ -20,7 +20,7 @@ TEST(TensorOpTest, Unsqueeze_1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -TEST(TensorOpTest, Unsqueeze_1_int32) { +TEST(UnsqueezeOpTest, Unsqueeze_1_int32) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{1}); @@ -29,7 +29,7 @@ TEST(TensorOpTest, Unsqueeze_1_int32) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -TEST(TensorOpTest, Unsqueeze_2) { +TEST(UnsqueezeOpTest, Unsqueeze_2) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{0, 4}); @@ -38,7 +38,7 @@ TEST(TensorOpTest, Unsqueeze_2) { test.Run(); } -TEST(TensorOpTest, Unsqueeze_3) { +TEST(UnsqueezeOpTest, Unsqueeze_3) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{2, 1, 0}); @@ -47,7 +47,7 @@ TEST(TensorOpTest, Unsqueeze_3) { test.Run(); } -TEST(TensorOpTest, Unsqueeze_scalar) { +TEST(UnsqueezeOpTest, Unsqueeze_scalar) { { OpTester test("Unsqueeze"); @@ -85,7 +85,7 @@ TEST(TensorOpTest, Unsqueeze_scalar) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_scalar_2) { +TEST(UnsqueezeOpTest, Unsqueeze_scalar_2) { { OpTester test("Unsqueeze"); @@ -105,7 +105,7 @@ TEST(TensorOpTest, Unsqueeze_scalar_2) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_Duplicate) { +TEST(UnsqueezeOpTest, Unsqueeze_Duplicate) { { OpTester test("Unsqueeze", 12); // opset 1-12 has axes attribute @@ -128,7 +128,7 @@ TEST(TensorOpTest, Unsqueeze_Duplicate) { } } -TEST(TensorOpTest, Unsqueeze_OutOfRange) { +TEST(UnsqueezeOpTest, Unsqueeze_OutOfRange) { { OpTester test("Unsqueeze", 12); // opset 1-12 has axes attribute test.AddAttribute("axes", std::vector{4}); @@ -149,7 +149,7 @@ TEST(TensorOpTest, Unsqueeze_OutOfRange) { } } -TEST(TensorOpTest, UnsqueezeNegAxis_3) { +TEST(UnsqueezeOpTest, UnsqueezeNegAxis_3) { { OpTester test("Unsqueeze", 12); // opset 1-12 has axes attribute test.AddAttribute("axes", std::vector{-4, 1, -6}); @@ -171,7 +171,7 @@ TEST(TensorOpTest, UnsqueezeNegAxis_3) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_1_int32_axes_input) { +TEST(UnsqueezeOpTest, Unsqueeze_1_int32_axes_input) { auto run_test = [](bool axes_is_initializer) { OpTester test("Unsqueeze", 13); @@ -185,7 +185,7 @@ TEST(TensorOpTest, Unsqueeze_1_int32_axes_input) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_3_axes_input) { +TEST(UnsqueezeOpTest, Unsqueeze_3_axes_input) { auto run_test = [](bool axes_is_initializer) { OpTester test("Unsqueeze", 13); @@ -200,7 +200,7 @@ TEST(TensorOpTest, Unsqueeze_3_axes_input) { } #if defined(USE_DNNL) -TEST(TensorOpTest, Unsqueeze_3_axes_input_bfloat16) { +TEST(UnsqueezeOpTest, Unsqueeze_3_axes_input_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; @@ -218,7 +218,7 @@ TEST(TensorOpTest, 
Unsqueeze_3_axes_input_bfloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -TEST(TensorOpTest, UnsqueezeNegAxis_3_bfloat16) { +TEST(UnsqueezeOpTest, UnsqueezeNegAxis_3_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 67fb35d26e6dc..559b521f18782 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -159,7 +159,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { // the internal NHWC operators are only included as part of contrib ops currently. as the EP requests the NHWC // version of the ONNX operator when matching a static kernel, those are required. -#if !defined(DISABLE_CONTRIB_OPS) +#if !defined(DISABLE_CONTRIB_OPS) && !defined(USE_ROCM) TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "transform/fusion/conv_relu_opset12.onnx"; @@ -256,10 +256,6 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { run_test(ort_model_path); } -// This test can be deprecated now as the code logic has been changed so the model is not applicable -// TEST(InternalTestingEP, TestRegisterAllocatorHandlesUsageInMultipleSessions) { -//} - // make sure allocators returned by SessionState::GetAllocator are valid when IExecutionProvider::ReplaceAllocator // is used. if something is off InferenceSession::Initialize will fail. TEST(InternalTestingEP, TestReplaceAllocatorDoesntBreakDueToLocalAllocatorStorage) { diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc index 23ec48fa649dd..93e688570631e 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.cc +++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc @@ -42,8 +42,9 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { } #endif + const auto& logger = DefaultLoggingManager().DefaultLogger(); Model model("test", false, ModelMetaData(), ORT_TSTR(""), IOnnxRuntimeOpSchemaRegistryList(), - {{domain_, opset_version_}}, {}, DefaultLoggingManager().DefaultLogger()); + {{domain_, opset_version_}}, {}, logger); std::vector input_args; std::unordered_map initializer_map; @@ -89,8 +90,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { ASSERT_STATUS_OK(graph.Resolve()); node.SetExecutionProviderType(ep_type); - OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), [](std::string const&) { return false; }, logger); const KernelCreateInfo* kernel_create_info = nullptr; ASSERT_STATUS_OK(info.TryFindKernel(&node, &kernel_create_info)); ASSERT_TRUE(kernel_create_info); @@ -139,7 +139,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { #pragma warning(disable : 6387) #endif OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs, outputs); - OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, DefaultLoggingManager().DefaultLogger()); + OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, logger); #ifdef 
_WIN32 #pragma warning(pop) #endif diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc index 5db69489afaef..f1fbb1cea7ea2 100644 --- a/onnxruntime/test/providers/partitioning_utils_test.cc +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -51,7 +51,7 @@ TEST(PartitioningUtilsTest, TestQDQHandling) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, @@ -82,7 +82,7 @@ static void CheckAllNodesProcessed(const std::function& std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); const auto is_node_supported = [&](const Node& /*node*/) -> bool { return true; diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index e3f09e92593df..55177cc7ed131 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -131,11 +131,16 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { ExpectedEPNodeAssignment::All); } +// disabled for QNN 2.28.0.241029 failed for accuracy validation +// Also fails on QNN 2.28.2. +// qdq@QNN_EP val: 3.6094117164611816 (err: 1.3094117641448975, err/output_range: 22.19342041015625%) +// qdq@CPU_EP val: 2.2905881404876709 (err: 0.0094118118286132812, err/output_range: 0.15952222049236298%) +// abs(qdq@QNN_EP - qdq@CPU_EP) / output_range = 22.033897399902344% // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all // nodes are supported by the QNN EP, and that the inference results are as accurate as CPU EP. // // Static int32 indices with axis = 1 -TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis1) { +TEST_F(QnnHTPBackendTests, DISABLED_GatherOp_IndicesStaticInt32_Axis1) { RunQDQGatherOpTest(TestInputDef({3, 3}, false, {1.0f, 1.2f, 1.9f, 2.3f, 3.4f, 3.9f, 4.5f, 5.7f, 5.9f}), TestInputDef({1, 2}, true, {0, 2}), {utils::MakeAttribute("axis", static_cast(1))}, diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 2af49a5e500d2..947ac19be40a8 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -189,8 +189,10 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_StaticBias_AU8_WU8_B } TEST_F(QnnHTPBackendTests, LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { - // QNN 2.24 LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide an - // explicit bias of all zeros to get around this bug. + // QNN 2.24 to 2.27: LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide + // an explicit bias of all zeros to get around this bug. + // QNN 2.28.0: Validation bug is fixed, but get accuracy errors. + // QNN 2.28.2: All fixed. for (size_t i = 0; i < 15; i++) { // Run it multiple times since this is an intermittent bug. 
RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 1.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), @@ -201,8 +203,9 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { } } -// Test accuracy of 16-bit QDQ LayerNorm with a static scale input. TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { + // QNN 2.28.0: Get accuracy errors. + // QNN 2.28.2: All fixed. RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), // Static TestInputDef(), @@ -213,7 +216,7 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { // Test accuracy of 8-bit QDQ LayerNorm with a dynamic scale input. // -// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. +// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. Still fails on QNN SDK 2.28.2. // Verbose logs: // Starting stage: Graph Transformations and Optimizations // C:\...\QNN\HTP\HTP\src\hexagon\prepare\graph_prepare.cc:203:ERROR:could not create op: q::flat_to_vtcm diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 708aac03ceb2e..5c6967761b1db 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -273,7 +273,9 @@ TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { } // Test QDQ per-channel MatMul with int8 act, int4 weights (static) -TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { +// QNN 2.27 regression. Also fails on QNN 2.28.2. +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_PerChannel_AS8_WeightInt4) { std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); std::vector input1_data = {-2.0f, -1.0f, -0.5f, 0.0f, 1.0f, 2.0f}; RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 236b66a2d8a78..e8282dbad9f72 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1023,6 +1023,81 @@ TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) { &ep_graph_checker); } +// Test option for offloading quantization of graph inputs and dequantization of graph outputs to the CPU EP. +TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { + // Returns a function that checks that the Q/DQ ops at the graph IO boundary are offloaded to CPU + // if the corresponding provider option is enabled. 
+  auto graph_checker_builder = [](bool offload_graph_io_quantization) -> std::function<void(const Graph&)> {
+    return [offload_graph_io_quantization](const Graph& graph) {
+      size_t num_q = 0;
+      size_t num_dq = 0;
+      size_t num_qnn_fused_node = 0;
+
+      for (const Node& node : graph.Nodes()) {
+        const std::string& ep_name = node.GetExecutionProviderType();
+        const std::string& op_type = node.OpType();
+
+        if (offload_graph_io_quantization && op_type == "QuantizeLinear") {
+          const bool consumes_graph_input = graph.IsInputsIncludingInitializers(node.InputDefs()[0]);
+          EXPECT_EQ(ep_name, kCpuExecutionProvider);
+          EXPECT_TRUE(consumes_graph_input);
+          num_q += 1;
+        } else if (offload_graph_io_quantization && op_type == "DequantizeLinear") {
+          const bool produces_graph_output = graph.IsOutput(node.OutputDefs()[0]);
+          EXPECT_EQ(ep_name, kCpuExecutionProvider);
+          EXPECT_TRUE(produces_graph_output);
+          num_dq += 1;
+        } else {
+          EXPECT_EQ(ep_name, kQnnExecutionProvider);
+          num_qnn_fused_node += 1;
+        }
+      }
+
+      EXPECT_EQ(num_q, static_cast<size_t>(offload_graph_io_quantization));
+      EXPECT_EQ(num_dq, static_cast<size_t>(offload_graph_io_quantization));
+      EXPECT_EQ(num_qnn_fused_node, 1);
+    };
+  };
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::vector<std::string> op_types = {
+      "Sigmoid",
+      "Transpose",
+      "Softmax",
+      "Sqrt",
+      "Elu",
+  };
+
+  // Test various QDQ ops with offloading of I/O quantization enabled and disabled.
+  for (auto op_type : op_types) {
+    for (int offload_io_quant = 0; offload_io_quant <= 1; offload_io_quant++) {
+      provider_options["offload_graph_io_quantization"] = offload_io_quant ? "1" : "0";
+      auto graph_checker = graph_checker_builder(offload_io_quant);
+      auto expected_ep_assignment = offload_io_quant ? ExpectedEPNodeAssignment::Some : ExpectedEPNodeAssignment::All;
+
+      float min_val = (op_type == "Sqrt") ?
0.0f : -10.0f; + TestInputDef input_def({1, 2, 2, 2}, false, GetFloatDataInRange(min_val, 10.0f, 8)); + auto f32_model_build_fn = BuildOpTestCase(op_type, {input_def}, {}, {}); + auto qdq_model_build_fn = BuildQDQOpTestCase(op_type, {input_def}, {}, {}); + TestQDQModelAccuracy(f32_model_build_fn, + qdq_model_build_fn, + provider_options, + /*opset*/ 21, + expected_ep_assignment, + /*abs_err*/ QDQTolerance(), + logging::Severity::kERROR, + /*qnn_ctx_model_path*/ "", + /*session_option_pairs*/ {}, + &graph_checker); + } + } +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 8a4f7f2a1f6b5..4feeb5f830508 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -134,7 +134,8 @@ void InferenceModel(const std::string& model_data, const char* log_id, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals, bool is_qnn_ep, - const std::unordered_map& session_option_pairs) { + const std::unordered_map& session_option_pairs, + std::function* graph_checker) { SessionOptions so; so.session_logid = log_id; for (auto key_value : session_option_pairs) { @@ -166,6 +167,10 @@ void InferenceModel(const std::string& model_data, const char* log_id, ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type; } + if (graph_checker) { + (*graph_checker)(graph); + } + const auto& outputs = graph.GetOutputs(); std::vector output_names; @@ -383,6 +388,7 @@ bool ReduceOpHasAxesInput(const std::string& op_type, int opset_version) { {"ReduceMean", 18}, {"ReduceProd", 18}, {"ReduceSum", 13}, + {"ReduceL2", 18}, }; const auto it = opset_with_axes_as_input.find(op_type); diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 7f55a44c748b6..a8670252ff9e0 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -457,13 +457,15 @@ DEF_QUANTIZE_VALUES_INT4_FUNC(UInt4x2, ParQuantizeLinearStdU4) * \param output_vals Initialized to the inference results. * \param is_qnn_ep Ture: QNN EP is used. False: CPU EP is used (default). * \param session_option_pairs extra session options. + * \param graph_checker Function called on the Graph. */ void InferenceModel(const std::string& model_data, const char* log_id, const ProviderOptions& provider_options, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals, bool is_qnn_ep = false, - const std::unordered_map& session_option_pairs = {}); + const std::unordered_map& session_option_pairs = {}, + std::function* graph_checker = nullptr); /** * If the ORT_UNIT_TEST_ENABLE_QNN_SAVER environment variable is enabled (set to 1), this function modifies @@ -515,6 +517,8 @@ struct QDQTolerance { * \param tolerance The percent tolerance (as fraction) QNN EP results are allowed to differ from the QDQ model * on CPU EP. This tolerance is a percentage of the output range. * \param log_severity The logger's severity setting. + * \param ep_graph_checker Function called on the Graph generated for the QNN EP's session. Used to check node + * EP assignment. 
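+ * Pass nullptr (the default) to skip this check.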
*/ template inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn& qdq_model_fn, @@ -523,7 +527,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe QDQTolerance tolerance = QDQTolerance(), logging::Severity log_severity = logging::Severity::kERROR, const std::string& qnn_ctx_model_path = "", - const std::unordered_map& session_option_pairs = {}) { + const std::unordered_map& session_option_pairs = {}, + std::function* qnn_ep_graph_checker = nullptr) { // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; @@ -607,7 +612,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. // Only need to apply the extra session options to this QDQ model inference on QNN EP InferenceModel(qdq_model_data, "qdq_model_logger", qnn_options, expected_ep_assignment, - qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs); + qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs, qnn_ep_graph_checker); } if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { diff --git a/onnxruntime/test/providers/qnn/reduce_op_test.cc b/onnxruntime/test/providers/qnn/reduce_op_test.cc index 13173d9a87f55..e4abe85908373 100644 --- a/onnxruntime/test/providers/qnn/reduce_op_test.cc +++ b/onnxruntime/test/providers/qnn/reduce_op_test.cc @@ -309,6 +309,27 @@ TEST_F(QnnCPUBackendTests, ReduceMeanOpset13) { ExpectedEPNodeAssignment::All); } +// +// ReduceL2 +// +TEST_F(QnnCPUBackendTests, ReduceL2Opset18) { + RunReduceTest("ReduceL2", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 18, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, ReduceL2Opset13) { + RunReduceTest("ReduceL2", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // Test creates a graph with a ReduceSum node, and checks that all nodes are supported by the QNN EP diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 018720fd8b71f..7541d94bac0c6 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -229,8 +229,16 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Tanh) { ExpectedEPNodeAssignment::All); } +// disabled for QNN 2.28.0.241029 backendValidateOpConfig failed +// still fails on QNN 2.28.2. +// QnnDsp [4294967295] has incorrect Value -32768, expected equal to 0. +// QnnDsp validateNativeOps node_token_6:qti.aisw:Tanh htp op validator failed 3110 +// QnnDsp registered validator failed => 3110 +// QnnDsp QnnBackend_validateOpConfig failed 3110 +// QnnDsp Wake up free backend (id: 1)'s thread(s) +// QnnDsp Failed to validate op node_token_6 with error 0xc26 // Tests accuracy of 16-bit QDQ Tanh. 
-TEST_F(QnnHTPBackendTests, UnaryOp_Tanh_U16) { +TEST_F(QnnHTPBackendTests, DISABLED_UnaryOp_Tanh_U16) { RunQDQOpTest("Tanh", {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, {}, diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 63327a028c6f4..0022d7fc0e184 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -342,8 +342,12 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { Graph& graph = model->MainGraph(); GraphViewer viewer(graph); + std::string trt_version = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR); + std::string cuda_version = std::to_string(CUDA_VERSION); + std::string ort_version = ORT_VERSION; + // get the hash for the model when loaded from file - HashValue model_hash = TRTGenerateId(viewer); + HashValue model_hash = TRTGenerateId(viewer, trt_version, cuda_version); ASSERT_NE(model_hash, 0); // now load the model from bytes and check the hash differs @@ -358,7 +362,7 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { // Test loading same model from file and byte steam. Hash values should be different Graph& graph2 = model2->MainGraph(); GraphViewer viewer2(graph2); - HashValue model_hash2 = TRTGenerateId(viewer2); + HashValue model_hash2 = TRTGenerateId(viewer2, trt_version, cuda_version); ASSERT_NE(model_hash, model_hash2); // Test loading same model from different path, see if hash values are same as well @@ -367,7 +371,7 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { ASSERT_TRUE(Model::Load(model_path, model3, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph3 = model3->MainGraph(); GraphViewer viewer3(graph3); - HashValue model_hash3 = TRTGenerateId(viewer3); + HashValue model_hash3 = TRTGenerateId(viewer3, trt_version, cuda_version); ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded"; } diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 9b1e87f6ec02e..8fc76da3495a8 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -152,6 +152,9 @@ def create_backend_test(test_name=None): if backend.supports_device("MIGRAPHX"): current_failing_tests += apply_filters(filters, "current_failing_tests_MIGRAPHX") + if backend.supports_device("WEBGPU"): + current_failing_tests += apply_filters(filters, "current_failing_tests_WEBGPU") + # Skip these tests for a "pure" DML onnxruntime python wheel. 
We keep these tests enabled for instances where both DML and CUDA # EPs are available (Windows GPU CI pipeline has this config) - these test will pass because CUDA has higher precedence than DML # and the nodes are assigned to only the CUDA EP (which supports these tests) diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 29680c98fb4de..2f8fb84c4c651 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -28,7 +28,6 @@ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference import unittest -from pathlib import Path def unique_element(lst): @@ -41,6 +40,8 @@ def unique_element(lst): class TestSymbolicShapeInference(unittest.TestCase): def test_symbolic_shape_infer(self): + from pathlib import Path + cwd = os.getcwd() test_model_dir = os.path.join(cwd, "..", "models") for filename in Path(test_model_dir).rglob("*.onnx"): diff --git a/onnxruntime/test/python/onnxruntime_test_python_tvm.py b/onnxruntime/test/python/onnxruntime_test_python_tvm.py deleted file mode 100644 index 0080bf53520f2..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_python_tvm.py +++ /dev/null @@ -1,242 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Module for unit testing of TVM EP -""" - -import os -import sys -import tempfile -import unittest -from typing import Any, AnyStr, Dict, List, Tuple - -import numpy -import tvm -from numpy.testing import assert_almost_equal -from onnx import ModelProto, TensorProto, mapping -from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info - -import onnxruntime - -numpy.random.seed(32) - - -def is_windows(): - """ - Function to determine the Windows system - """ - return sys.platform.startswith("win") - - -def get_model_with_dynamic_shapes() -> ModelProto: - """ - Create model with Dynamic Shapes - """ - x = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - node1 = make_node("MatMul", ["X", "A"], ["XA"]) - node2 = make_node("Add", ["XA", "B"], ["Y"]) - graph = make_graph([node1, node2], "lr", [x, a, b], [y]) - onnx_model = make_model(graph) - return onnx_model - - -def get_model_with_fixed_shapes() -> ModelProto: - """ - Create model with Static Shapes - """ - - def change_input_shape(model: ModelProto, ind: int, shape: Tuple) -> None: - """ - Function to change the input form - """ - dims = model.graph.input[ind].type.tensor_type.shape.dim - assert len(dims) == len(shape), "Input rank and new shape rank do not match." 
- for i, new_dim in enumerate(shape): - model.graph.input[ind].type.tensor_type.shape.dim[i].dim_value = new_dim - - dynamic_model = get_model_with_dynamic_shapes() - change_input_shape(dynamic_model, 0, (1, 2)) # X - change_input_shape(dynamic_model, 1, (2, 2)) # A - change_input_shape(dynamic_model, 2, (1, 2)) # B - return dynamic_model - - -def get_input_data_for_model_with_dynamic_shapes() -> Dict[AnyStr, numpy.ndarray]: - """ - Create input data for model with dynamic shapes - """ - a = numpy.random.randn(2, 2).astype(numpy.float32) # pylint: disable=invalid-name - b = numpy.random.randn(1, 2).astype(numpy.float32) # pylint: disable=invalid-name - x = numpy.random.randn(1, 2).astype(numpy.float32) # pylint: disable=invalid-name - data = {"A": a, "B": b, "X": x} - return data - - -def get_input_data_for_model_with_fixed_shapes(onnx_model: ModelProto) -> Dict[AnyStr, numpy.ndarray]: - """ - Create input data for model with static shapes - """ - - def get_onnx_input_names(model: ModelProto) -> List[AnyStr]: - inputs = [node.name for node in model.graph.input] - initializer = [node.name for node in model.graph.initializer] - inputs = list(set(inputs) - set(initializer)) - return sorted(inputs) - - def get_onnx_input_types(model: ModelProto) -> List[numpy.dtype]: - input_names = get_onnx_input_names(model) - return [ - mapping.TENSOR_TYPE_TO_NP_TYPE[node.type.tensor_type.elem_type] - for node in sorted(model.graph.input, key=lambda node: node.name) - if node.name in input_names - ] - - def get_onnx_input_shapes(model: ModelProto) -> List[List[int]]: - input_names = get_onnx_input_names(model) - return [ - [dv.dim_value for dv in node.type.tensor_type.shape.dim] - for node in sorted(model.graph.input, key=lambda node: node.name) - if node.name in input_names - ] - - input_names = get_onnx_input_names(onnx_model) - input_shapes = get_onnx_input_shapes(onnx_model) - input_types = get_onnx_input_types(onnx_model) - assert len(input_names) == len(input_types) == len(input_shapes) - random_inputs = [numpy.random.uniform(size=shape).astype(dtype) for shape, dtype in zip(input_shapes, input_types)] - return dict(zip(input_names, random_inputs)) - - -def get_input_names_and_shapes(data: Dict[AnyStr, numpy.ndarray]) -> Tuple[List[AnyStr], List[AnyStr]]: - """ - Create text representations for model input names and shapes - """ - keys = list(data.keys()) - values = [data[key] for key in keys] - return ( - list(data.keys()), - [str(value.shape).replace(",", "").replace("(", "[").replace(")", "]") for value in values], - ) - - -def get_cpu_output(onnx_model: ModelProto, data: Dict[AnyStr, numpy.ndarray]) -> List[numpy.ndarray]: - """ - Run inference with CPUExecutionProvider - """ - # pylint: disable=no-member - sess = onnxruntime.InferenceSession( - onnx_model.SerializeToString(), - providers=["CPUExecutionProvider"], - ) - output = sess.run(None, data) - return output - - -def get_tvm_output( - onnx_model: ModelProto, data: Dict[AnyStr, numpy.ndarray], provider_options: Dict[AnyStr, Any] -) -> List[numpy.ndarray]: - """ - Run inference with TVMExecutionProvider - """ - session_options = onnxruntime.SessionOptions() # pylint: disable=no-member - session_options.log_severity_level = 0 - session_options.log_verbosity_level = 0 - # pylint: disable=no-member - session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - - sess = onnxruntime.InferenceSession( - onnx_model.SerializeToString(), - session_options, - providers=["TvmExecutionProvider"], - 
provider_options=[provider_options], - ) - - output = sess.run(None, data) - return output - - -# pylint: disable=no-member -def compile_virtual_machine(model: ModelProto, target_str: AnyStr) -> tvm.runtime.vm.Executable: - """ - Compile ONNX model using VirtualMachine - """ - ir_mod, _ = tvm.relay.frontend.from_onnx( - model, - opset=model.opset_import[0].version, - freeze_params=True, - ) - target = tvm.target.Target(target=target_str, host=target_str) - return tvm.relay.backend.vm.compile(ir_mod, target) - - -def serialize_virtual_machine(vm_exec: tvm.runtime.vm.Executable) -> AnyStr: - """ - Serialize VirtualMachine - """ - temp_directory = tempfile.mkdtemp() - path_consts = os.path.join(temp_directory, "consts") - vm_exec.move_late_bound_consts(path_consts, byte_limit=256) - lib_path = os.path.join(temp_directory, f"model.{'dll' if is_windows() else 'so'}") - code_path = os.path.join(temp_directory, "model.ro") - code, lib = vm_exec.save() - lib.export_library(lib_path) - with open(code_path, "wb") as code_file: - code_file.write(code) - return temp_directory - - -class TestTVM(unittest.TestCase): - """ - Unit tests for TVM EP - """ - - @staticmethod - def test_accuracy_for_model_with_dynamic_shapes(): - """ - Accuracy test for model with dynamic shapes - """ - onnx_model = get_model_with_dynamic_shapes() - data = get_input_data_for_model_with_dynamic_shapes() - - cpu_output = get_cpu_output(onnx_model, data) - names, shapes = get_input_names_and_shapes(data) - provider_options = dict( - target="llvm", - input_names=" ".join(names), - input_shapes=" ".join(shapes), - ) - tvm_output = get_tvm_output(onnx_model, data, provider_options) - - assert_almost_equal(cpu_output, tvm_output, decimal=5) - - @staticmethod - def test_accuracy_for_tvm_so(): - """ - Accuracy test for TVMso Ep - """ - onnx_model = get_model_with_fixed_shapes() - data = get_input_data_for_model_with_fixed_shapes(onnx_model) - - cpu_output = get_cpu_output(onnx_model, data) - - compiled_vm_exec = compile_virtual_machine(onnx_model, target_str="llvm") - so_folder = serialize_virtual_machine(compiled_vm_exec) - provider_options = dict( - target="llvm", - so_folder=so_folder, - ) - tvm_output = get_tvm_output(onnx_model, data, provider_options) - - assert_almost_equal(cpu_output, tvm_output, decimal=5) - - -if __name__ == "__main__": - if "TvmExecutionProvider" not in onnxruntime.get_available_providers(): - raise AssertionError(f"Unable to find 'TvmExecutionProvider' in {onnxruntime.get_available_providers()}") - unittest.main() diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index cf7fc292ea86b..82193d08684c6 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -1,3 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +from __future__ import annotations + import uuid from pathlib import Path @@ -661,3 +668,29 @@ def generate_random_initializer(initializer_name, tensor_shape, tensor_dtype, me tensor = np.random.normal(mean, dev, tensor_shape).astype(tensor_dtype) init = onnx.numpy_helper.from_array(tensor, initializer_name) return init + + +def get_tensor_consumers_and_producers( + model: onnx.ModelProto, +) -> tuple[dict[str, list[onnx.NodeProto]], dict[str, onnx.NodeProto]]: + """ + Returns a tuple containing the following python dictionaries: + - consumers: maps a tensor name to the list of nodes that have that tensor as an input. + - producers: maps a tensor name to the node that generates this tensor as an output. + """ + consumers: dict[str, list[onnx.NodeProto]] = {} + producers: dict[str, onnx.NodeProto] = {} + for node in model.graph.node: + # Iterate through node's inputs to build the consumers dictionary. + for input_name in node.input: + if input_name: + if input_name not in consumers: + consumers[input_name] = [] + + consumers[input_name].append(node) + + # Iterate through node's outputs to build the producers dictionary. + for output_name in node.output: + producers[output_name] = node + + return (consumers, producers) diff --git a/onnxruntime/test/python/quantization/test_get_qdq_config.py b/onnxruntime/test/python/quantization/test_get_qdq_config.py new file mode 100644 index 0000000000000..58d00272475cd --- /dev/null +++ b/onnxruntime/test/python/quantization/test_get_qdq_config.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantType, get_qdq_config, quantize + + +class TestGetQDQConfig(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.int_qdq_config_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_add_model( + self, + shape: list[int], + tensor_type: onnx.TensorProto.DataType, + weight: onnx.TensorProto | None = None, + opset: int = 21, + ) -> onnx.ModelProto: + """ + Returns an onnx.ModelProto with a single Add operator. The second input can be optionally made + a static weight. 
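+        When `weight` is provided it is added as an initializer; otherwise a second graph input ("input_1") is created.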
+ """ + graph_inputs = [onnx.helper.make_tensor_value_info("input_0", tensor_type, shape)] + graph_outputs = [onnx.helper.make_tensor_value_info("output_0", tensor_type, shape)] + initializers = [] + add_input_names = ["input_0"] + + if weight is not None: + initializers.append(weight) + add_input_names.append(weight.name) + else: + graph_inputs.append(onnx.helper.make_tensor_value_info("input_1", tensor_type, shape)) + add_input_names.append("input_1") + + add_node = onnx.helper.make_node("Add", add_input_names, ["output_0"], name="Add0") + + graph = onnx.helper.make_graph( + [add_node], + "AddGraph", + graph_inputs, + graph_outputs, + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", opset)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_basic_args(self): + """ + Test that get_qdq_config() returns a config that sets the basic args. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=21) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + calibrate_method=CalibrationMethod.Percentile, + calibrate_args={"percentile": 99.98}, # Converted to extra_options + activation_type=QuantType.QUInt16, + weight_type=QuantType.QInt16, + per_channel=True, + reduce_range=True, + nodes_to_exclude=["Mul"], + # Other options converted to extra_options: + min_real_range=0.0001, + keep_removable_activations=True, + activation_symmetric=True, + weight_symmetric=True, + ) + self.assertEqual(qdq_config.calibrate_method, CalibrationMethod.Percentile) + self.assertEqual(qdq_config.activation_type, QuantType.QUInt16) + self.assertEqual(qdq_config.weight_type, QuantType.QInt16) + self.assertTrue(qdq_config.per_channel) + self.assertTrue(qdq_config.reduce_range) + self.assertEqual(set(qdq_config.nodes_to_exclude), {"Mul"}) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + + # Check that calibration args are translated to extra_options. + self.assertEqual(qdq_config.extra_options["CalibPercentile"], 99.98) + + # Check that other args are also translated to extra_options. + self.assertEqual(qdq_config.extra_options["MinimumRealRange"], 0.0001) + self.assertTrue(qdq_config.extra_options["QDQKeepRemovableActivations"]) + self.assertTrue(qdq_config.extra_options["ActivationSymmetric"]) + self.assertTrue(qdq_config.extra_options["WeightSymmetric"]) + + # The following options should always be set to specific values. + self.assertTrue(qdq_config.extra_options["ForceQuantizeNoInputCheck"]) + self.assertEqual(qdq_config.quant_format, QuantFormat.QDQ) + + # Should use onnx domain Q/DQ ops because onnx opset >= 21. + self.assertFalse(qdq_config.extra_options.get("UseQDQContribOps", False)) + + def test_exclude_nodes_callable(self): + """ + Test passing a function/callable to exclude nodes from quantization. 
+ """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=21) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Local function that excludes all "Add" nodes. + def should_exclude_node_(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: + return node.op_type == "Add" + + qdq_config = get_qdq_config( + float_model, + data_reader, + nodes_to_exclude=should_exclude_node_, + ) + + expected_excluded_nodes = set([node.name for node in float_model.graph.node if node.op_type == "Add"]) + self.assertTrue(bool(expected_excluded_nodes)) + self.assertEqual(set(qdq_config.nodes_to_exclude), expected_excluded_nodes) + + def test_external_data(self): + """ + Test that get_qdq_config() returns a config that enables external data + if the input model has external data. + """ + + # Create model with a weight large enough (> 1024 bytes) to be stored externally. + shape = [1, 32, 32] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + large_weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, large_weight) + float_model_path = os.path.join(self._tmp_dir_path, "add_ext_data_int_qdq_config.onnx") + + onnx.save_model( + float_model, + float_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="add_ext_data_int_qdq_config.bin", + ) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(0, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Create a quantization config and check that it sets boolean to use external data + qdq_config = get_qdq_config( + float_model_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8 + ) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + self.assertTrue(qdq_config.use_external_data_format) + + # Quantize the model and check computational correctness against float model. + qdq_model_path = os.path.join(self._tmp_dir_path, "add_ext_data_int_qdq_config.qdq.onnx") + quantize(float_model_path, qdq_model_path, qdq_config) + + expected_op_counts = {"DequantizeLinear": 3, "QuantizeLinear": 2, "Add": 1} + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + # The quantized weight should still be stored in an external file. 
+ qdq_model = onnx.load_model(qdq_model_path, load_external_data=False) + weight_quantized = next( + ( + initializer + for initializer in qdq_model.graph.initializer + if initializer.name == f"{large_weight.name}_quantized" + ), + None, + ) + self.assertIsNotNone(weight_quantized) + self.assertEqual(weight_quantized.data_location, onnx.TensorProto.EXTERNAL) + + def test_use_qdq_contrib_ops_for_int16_opset19(self): + """ + Test that get_qdq_config() returns a config that forces 'com.microsoft' Q/DQ ops for + use of int16 in opset < 21. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=19) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + activation_type=QuantType.QUInt16, + weight_type=QuantType.QInt8, + ) + + self.assertEqual(qdq_config.activation_type, QuantType.QUInt16) + self.assertTrue(qdq_config.extra_options["UseQDQContribOps"]) + + def test_use_qdq_contrib_ops_for_int4_opset19(self): + """ + Test that get_qdq_config() returns a config that forces 'com.microsoft' Q/DQ ops for + use of int4 in opset < 21. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=19) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Use int4 in tensor quantization overrides. This should still force use of 'com.microsoft' Q/DQ ops. + qdq_config = get_qdq_config( + float_model, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + tensor_quant_overrides={"weight": [{"quant_type": QuantType.QInt4}]}, + ) + + self.assertEqual(qdq_config.extra_options["TensorQuantOverrides"]["weight"][0]["quant_type"], QuantType.QInt4) + self.assertTrue(qdq_config.extra_options["UseQDQContribOps"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index 291bf42405d58..755c7fae5e3e8 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -4,14 +4,23 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +from __future__ import annotations import itertools +import os +import tempfile import unittest import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_type_count, + check_qtype_by_node_type, + get_tensor_consumers_and_producers, +) from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static @@ -519,5 +528,160 @@ def test_pad_with_empty_string_input_name(self): self.assertNotEqual(name, "_quantized") +class TestQDQPad(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.pad_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_pad_model( + self, + mode: str, + constant_value: float | None = None, + opset: int = 21, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + ) -> onnx.ModelProto: + num_pads_start = 1 + input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, (3, 2)) + output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, (3, 2 + num_pads_start)) + + initializers = [] + pad_input_names = ["input_0"] + attrs = {"mode": mode} + + pads_data = np.array([0, num_pads_start, 0, 0], dtype=np.int64) # Pad one val at beginning of axis 1. + if opset >= 11: + initializers.append(onnx.numpy_helper.from_array(pads_data, "pads")) + pad_input_names.append("pads") + else: + attrs["pads"] = pads_data.tolist() + + if mode == "constant" and constant_value is not None: + if opset >= 11: + initializers.append(onnx.helper.make_tensor("constant_value", float_type, [], [constant_value])) + pad_input_names.append("constant_value") + else: + attrs["value"] = float(constant_value) + + pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", **attrs) + + graph = onnx.helper.make_graph( + [pad_node], + "PadFloat", + [input_0], + [output_0], + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", opset)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_qdq_pad_qparams(self): + """ + Test that QDQ Pad has equal scale/zero-point for its input and output for certain configurations. + """ + test_configs = [ + # Opset 21 + ("constant", None, 21, onnx.TensorProto.FLOAT), + ("constant", None, 21, onnx.TensorProto.FLOAT16), + ("constant", 0, 21, onnx.TensorProto.FLOAT), + ("constant", 0, 21, onnx.TensorProto.FLOAT16), + ("constant", 10.0, 21, onnx.TensorProto.FLOAT), + ("constant", 10.0, 21, onnx.TensorProto.FLOAT16), + ("reflect", None, 21, onnx.TensorProto.FLOAT), + ("reflect", None, 21, onnx.TensorProto.FLOAT16), + ("edge", None, 21, onnx.TensorProto.FLOAT), + ("edge", None, 21, onnx.TensorProto.FLOAT16), + ("wrap", None, 21, onnx.TensorProto.FLOAT), + ("wrap", None, 21, onnx.TensorProto.FLOAT16), + # Model with opset 10 will use pad of opset 2, which uses attributes instead of inputs. + # Opset 10 Q/DQ ops don't support float16. 
+ ("constant", None, 10, onnx.TensorProto.FLOAT), + ("constant", 0, 10, onnx.TensorProto.FLOAT), + ("constant", 10.0, 10, onnx.TensorProto.FLOAT), + ("reflect", None, 10, onnx.TensorProto.FLOAT), + ("edge", None, 10, onnx.TensorProto.FLOAT), + ] + + for pad_mode, constant_value, opset, float_type in test_configs: + with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset, float_type=float_type): + label = f"_{pad_mode}_{constant_value}_opset{opset}_{onnx.TensorProto.DataType.Name(float_type)}" + float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx") + + float_model = self.build_pad_model(pad_mode, constant_value, opset=opset, float_type=float_type) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type) + input_data_list = [ + {"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np_dtype)}, + {"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + ) + + expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1} + if constant_value is not None and opset >= 11: + expected_op_counts["DequantizeLinear"] += 1 # The constant padding value is quantized. + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + if pad_mode != "reflect": + # Do not check model correctness for 'reflect' mode because ONNX Runtime implementation does + # not match the ONNX reference implementation. See the following issue: + # https://github.com/microsoft/onnxruntime/issues/20801 + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + qdq_model = onnx.load_model(qdq_model_path) + quant_output_same_as_input = False + + if pad_mode in ("reflect", "edge", "wrap"): + quant_output_same_as_input = True + + if pad_mode == "constant" and constant_value in (None, 0): + quant_output_same_as_input = True + + pad_node = next((node for node in qdq_model.graph.node if node.op_type == "Pad"), None) + self.assertNotEqual(pad_node, None) + self.assertEqual(pad_node.op_type, "Pad") + + # Get the parent and child nodes of the Pad and check that they are DQ/Q. + consumers, producers = get_tensor_consumers_and_producers(qdq_model) + input_dq_node = producers.get(pad_node.input[0], None) + self.assertNotEqual(input_dq_node, None) + self.assertEqual(input_dq_node.op_type, "DequantizeLinear") + + output_q_node = consumers.get(pad_node.output[0], [None])[0] + self.assertNotEqual(output_q_node, None) + self.assertEqual(output_q_node.op_type, "QuantizeLinear") + + # Check that the Pad's input DQ uses the same scale/zp as the Pad's output Q. 
+ if quant_output_same_as_input: + self.assertEqual(input_dq_node.input[1], output_q_node.input[1]) # Same scale + self.assertEqual(input_dq_node.input[2], output_q_node.input[2]) # Same zero-point + else: + self.assertNotEqual(input_dq_node.input[1], output_q_node.input[1]) + self.assertNotEqual(input_dq_node.input[2], output_q_node.input[2]) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_slice.py b/onnxruntime/test/python/quantization/test_op_slice.py new file mode 100644 index 0000000000000..bfb9fc6b46bbd --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_slice.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_type_count, + get_tensor_consumers_and_producers, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestQDQSlice(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.slice_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_slice_model( + self, + input_shape: list[int], + input_tensor_type: onnx.TensorProto.DataType, + starts: list[int], + ends: list[int], + axes: list[int] | None = None, + steps: list[int] | None = None, + ) -> onnx.ModelProto: + """ + Returns an onnx.ModelProto with a single Slice operator. + """ + input_0 = onnx.helper.make_tensor_value_info("input_0", input_tensor_type, input_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", input_tensor_type, None) + + initializers = [ + onnx.numpy_helper.from_array(np.array(starts, dtype=np.int64), "starts"), + onnx.numpy_helper.from_array(np.array(ends, dtype=np.int64), "ends"), + ] + slice_input_names = ["input_0", "starts", "ends"] + + if axes: + initializers.append(onnx.numpy_helper.from_array(np.array(axes, dtype=np.int64), "axes")) + slice_input_names.append("axes") + + if steps: + if not axes: + slice_input_names.append("") # Empty axes input. + initializers.append(onnx.numpy_helper.from_array(np.array(steps, dtype=np.int64), "steps")) + slice_input_names.append("steps") + + slice_node = onnx.helper.make_node("Slice", slice_input_names, ["output_0"], name="Slice0") + + graph = onnx.helper.make_graph( + [slice_node], + "SliceGraph", + [input_0], + [output_0], + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_qdq_slice_qparams(self): + """ + Test that QDQ Slice has equal scale/zero-point for its input and output. 
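+ Slice only rearranges data, so the quantizer is expected to propagate the input's quantization parameters to the output.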
+ """ + test_configs = [onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16] + + for onnx_tensor_type in test_configs: + with self.subTest(onnx_tensor_type=onnx_tensor_type): + label = f"{onnx.TensorProto.DataType.Name(onnx_tensor_type)}" + float_model_path = os.path.join(self._tmp_dir_path, f"slice.{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"slice.{label}.qdq.onnx") + + input_shape = [2, 4] + float_model = self.build_slice_model( + input_shape=input_shape, + input_tensor_type=onnx_tensor_type, + starts=[1, 0], + ends=[2, 3], + axes=None, + steps=[1, 2], + ) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type) + input_data_list = [ + {"input_0": np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=np_dtype)}, + {"input_0": np.array([[-1.0, -2.0, -3.0, -4.0], [-5.0, -6.0, -7.0, -8.0]], dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + extra_options={"ForceQuantizeNoInputCheck": True}, + ) + expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Slice": 1} + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + qdq_model = onnx.load_model(qdq_model_path) + + slice_node = next((node for node in qdq_model.graph.node if node.op_type == "Slice"), None) + self.assertNotEqual(slice_node, None) + self.assertEqual(slice_node.op_type, "Slice") + + # Get the parent and child nodes of the Slice and check that they are DQ/Q. + consumers, producers = get_tensor_consumers_and_producers(qdq_model) + input_dq_node = producers.get(slice_node.input[0], None) + self.assertNotEqual(input_dq_node, None) + self.assertEqual(input_dq_node.op_type, "DequantizeLinear") + + output_q_node = consumers.get(slice_node.output[0], [None])[0] + self.assertNotEqual(output_q_node, None) + self.assertEqual(output_q_node.op_type, "QuantizeLinear") + + # Check that the Slice's input DQ uses the same scale/zp as the Slice's output Q. + self.assertEqual(input_dq_node.input[1], output_q_node.input[1]) + self.assertEqual(input_dq_node.input[2], output_q_node.input[2]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_softmax.py b/onnxruntime/test/python/quantization/test_op_softmax.py index 3416198450137..e5bc6288c91e2 100644 --- a/onnxruntime/test/python/quantization/test_op_softmax.py +++ b/onnxruntime/test/python/quantization/test_op_softmax.py @@ -213,6 +213,40 @@ def test_quantize_softmax(self): self.quantize_softmax_test_qop(QuantType.QUInt8, QuantType.QUInt8) self.quantize_softmax_test_qdq(QuantType.QUInt8, QuantType.QUInt8) + def test_bug_fix_exclude_softmax(self): + """ + Test fix to bug that happens when softmax is excluded from quantization, but + the quantization tool still tries to assign it a tensor range of [0.0, 1.0]. 
+ """ + np.random.seed(1) + model_fp32_path = "softmax_fp32.onnx" + model_qdq_path = "softmax_bug_exclude_softmax.qdq.onnx" + self.construct_model_conv_softmax( + model_fp32_path, + [1, 2, 26, 42], + [3, 2, 3, 3], + [1, 3, 24, 40], + {"axis": -2}, + [1, 3, 24, 40], + add_ms_domain_opset=False, + ) + data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) + data_reader.rewind() + + # Bug would cause an exception during quantization. + quantize_static( + model_fp32_path, + model_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + nodes_to_exclude=["Softmax"], + ) + + qdq_model = onnx.load(Path(model_qdq_path)) + self.assertIn("Softmax", {node.op_type for node in qdq_model.graph.node}) + def test_quantize_softmax_s8s8(self): self.quantize_softmax_test_qop( QuantType.QInt8, diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index b99c11abf6d2c..23b397ffd80e1 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -20,10 +20,12 @@ check_op_type_count, check_op_type_order, create_clip_node, + get_tensor_consumers_and_producers, ) from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static, write_calibration_table from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData +from onnxruntime.quantization.quant_utils import quantize_nparray class TestQDQFormat(unittest.TestCase): @@ -1726,5 +1728,479 @@ def test_json_serialization(self): write_calibration_table(new_calibrate_tensors_range) +class TestAdjustWeightScaleForInt32Bias(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.adj_int32_bias_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_conv_test_model( + self, + input0_shape: list[int], + weight_shape: list[int], + onnx_float_type: onnx.TensorProto.DataType, + ): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input0_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + + tiny_value = 1e-7 if np_float_type == np.float32 else 0.007782 + # weight_scale = 2*tiny_value / 255.0 = 7.84313725490196e-10 + + weight_data = np.full(weight_shape, tiny_value, dtype=np_float_type) + with np.nditer(weight_data, op_flags=["readwrite"]) as it: + for i, x in enumerate(it): + if i % 2 == 0: + x[...] = -x + + weight = onnx.numpy_helper.from_array(weight_data, "weight") + + # if we set input_scale to 0.05, then normally bias_scale would be + # (input_scale * weight_scale) => (0.05 * 7.84314e-10) => 3.9215686274509805e-11 + # + # If we quantize the f32 bias with this bias_scale, we get + # [5.0/bias_scale, 4.0/bias_scale] = [127500000000, 102000000000]. These quantized bias values exceed the + # range of int32. + # + # The ORT quantization tool will clamp these out-of-bounds values to int32::max(), + # which can be very inaccurate. + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + with np.nditer(bias_data, op_flags=["readwrite"]) as it: + for i, x in enumerate(it): + if i % 2 == 0: + x[...] 
= 5.0 if np_float_type == np.float32 else 1400 + else: + x[...] = -4.5 if np_float_type == np.float32 else -1200 + + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convfloat", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_adjust_weight_scale_for_int32_bias(self): + """ + Test adjustment of weight input's scale to ensure int32 bias's scale is not too small. + """ + test_configs = [ + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT, False), + (onnx.TensorProto.FLOAT16, True), + (onnx.TensorProto.FLOAT16, False), + ] + + for float_type, per_channel in test_configs: + with self.subTest(float_type=float_type, per_channel=per_channel): + label = f"_f{float_type}_perchannel{per_channel}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.qdq.onnx") + + # Create float model with a Conv that has tiny weight values. + # This tiny weight scale would normally create a very small bias scale that will saturate + # bias's int32 range. But, the qdq_quantizer adjusts the weight's scale to ensure this doesn't happen. + input0_shape = [1, 2, 4, 4] + weight_shape = [2, 2, 2, 2] + float_model = self.build_conv_test_model(input0_shape, weight_shape, float_type) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(float_type) + input0_rmin = 0.0 + input0_scale = 0.05 if float_type == onnx.TensorProto.FLOAT else 0.01 + input0_rmax = (input0_scale * 255.0) + input0_rmin + input_data_list = [ + {"input_0": np.full(input0_shape, input0_rmin, dtype=np_float_type)}, + {"input_0": np.full(input0_shape, (input0_rmax - input0_rmin) / 2.0, dtype=np_float_type)}, + {"input_0": np.full(input0_shape, input0_rmax, dtype=np_float_type)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + per_channel=per_channel, + ) + + # Check correctness + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + def build_model_convs_share_bias( + self, + input0_shape: list[int], + weight_shape: list[int], + onnx_float_type: onnx.TensorProto.DataType, + ): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input0_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx_float_type, None) + + weight_0_data = np.ones(weight_shape, dtype=np_float_type) + weight_0 = onnx.numpy_helper.from_array(weight_0_data, "weight_0") + + weight_1_data = np.full(weight_shape, 0.5, dtype=np_float_type) + weight_1 = onnx.numpy_helper.from_array(weight_1_data, "weight_1") + + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + bias_shared = onnx.numpy_helper.from_array(bias_data, "bias_shared") + + conv_0_node = 
onnx.helper.make_node("Conv", ["input_0", "weight_0", "bias_shared"], ["output_0"], name="Conv0") + conv_1_node = onnx.helper.make_node("Conv", ["input_0", "weight_1", "bias_shared"], ["output_1"], name="Conv1") + graph = onnx.helper.make_graph( + [conv_0_node, conv_1_node], + "ConvWithSharedBiasToDup", + [input_0], + [output_0, output_1], + initializer=[weight_0, weight_1, bias_shared], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_dup_shared_bias(self): + """ + Test duplicating a bias that is shared by two nodes that want to quantize their bias to int32. + """ + float_model_path = os.path.join(self._tmp_dir_path, "convs_share_bias.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "convs_share_bias.qdq.onnx") + + # Create float model with a Convs that share a bias input. The QDQ quantizer should add a + # duplicate bias so that each node has its own. + input0_shape = [1, 2, 4, 4] + weight_shape = [2, 2, 2, 2] + float_model = self.build_model_convs_share_bias(input0_shape, weight_shape, onnx.TensorProto.FLOAT) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + input0_rmin = 0.0 + input0_scale = 0.05 + input0_rmax = (input0_scale * 255.0) + input0_rmin + input_data_list = [ + {"input_0": np.full(input0_shape, input0_rmin, dtype=np.float32)}, + {"input_0": np.full(input0_shape, (input0_rmax - input0_rmin) / 2.0, dtype=np.float32)}, + {"input_0": np.full(input0_shape, input0_rmax, dtype=np.float32)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + ) + + qdq_model = onnx.load_model(qdq_model_path) + bias_names = set() + + for node in qdq_model.graph.node: + if node.op_type == "DequantizeLinear" and node.input[0].startswith("bias_shared"): + bias_names.add(node.input[0]) + + self.assertEqual(len(bias_names), 2) + + +class TestQDQPrequantWeights(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.prequant_weight") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_conv_model( + self, + inp_shape: list[int], + weight_quant_data: np.ndarray, + weight_scale_data: np.ndarray, + weight_zp_data: np.ndarray, + bias_data: np.ndarray, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + ): + """ + Builds a model with a Conv that has a pre-quantized constant weight input. 
+ """ + input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, None) + weight_quant = onnx.numpy_helper.from_array(weight_quant_data, "weight_quant") + weight_scale = onnx.numpy_helper.from_array(weight_scale_data, "weight_scale") + weight_zp = onnx.numpy_helper.from_array(weight_zp_data, "weight_zp") + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + dq_node = onnx.helper.make_node( + "DequantizeLinear", ["weight_quant", "weight_scale", "weight_zp"], ["weight_dequant"], name="DQ0" + ) + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight_dequant", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [dq_node, conv_node], + "ConvPreQuantWeight", + [input_0], + [output_0], + initializer=[weight_quant, weight_scale, weight_zp, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + return onnx.shape_inference.infer_shapes(model) + + def build_conv_dynamic_weight_model( + self, + input_quant_data: np.ndarray, + input_scale_data: np.ndarray, + input_zp_data: np.ndarray, + weight_shape: list[int], + bias_data: np.ndarray, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + ): + """ + Builds a model with a Conv that has a dynamic float weight input, but a constant + pre-quantized input[0]. + """ + dyn_weight = onnx.helper.make_tensor_value_info("dyn_weight", float_type, weight_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, None) + input_quant = onnx.numpy_helper.from_array(input_quant_data, "input_quant") + input_scale = onnx.numpy_helper.from_array(input_scale_data, "input_scale") + input_zp = onnx.numpy_helper.from_array(input_zp_data, "input_zp") + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + dq_node = onnx.helper.make_node( + "DequantizeLinear", ["input_quant", "input_scale", "input_zp"], ["input_dequant"], name="DQ0" + ) + conv_node = onnx.helper.make_node("Conv", ["input_dequant", "dyn_weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [dq_node, conv_node], + "ConvPreQuantInput_DynamicWeight", + [dyn_weight], + [output_0], + initializer=[input_quant, input_scale, input_zp, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + return onnx.shape_inference.infer_shapes(model) + + def test_quantize_with_prequantized_weights(self): + """ + Test quantization of Conv with pre-quantized weights. 
+ """ + rng = np.random.default_rng(123) + test_configs = [onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16] + + for float_type in test_configs: + with self.subTest(float_type=float_type): + label = f"_{onnx.TensorProto.DataType.Name(float_type)}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv.f32.prequant_weight{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"conv.prequant_weight{label}.qdq.onnx") + + inp_shape = [1, 2, 100, 100] + weight_shape = [2, 2, 20, 20] + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type) + + # range = 2.0, scale = 2/254, zp = 0 + weight_scale_data = np.array(2 / 254, dtype=np_dtype) + weight_zp_data = np.array(0, dtype=np.int8) + weight_data = np.linspace(-1.0, 1.0, num=1600, dtype=np_dtype).reshape(weight_shape) + weight_quant_data = quantize_nparray( + onnx.TensorProto.INT8, weight_data, weight_scale_data, weight_zp_data + ) + + bias_data = np.array([-10.0, 10.0], dtype=np_dtype) + float_model = self.build_conv_model( + inp_shape, weight_quant_data, weight_scale_data, weight_zp_data, bias_data, float_type + ) + + onnx.checker.check_model(float_model, True) + onnx.save_model(float_model, float_model_path) + + # Check that the input model only has a pre-quantized weight and save its scale/zero-point + # to check that it doesn't change after quantization. + float_node_counts = {"QuantizeLinear": 0, "DequantizeLinear": 1} + check_op_type_count(self, float_model_path, **float_node_counts) + conv_node_original = next((node for node in float_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node_original, None) + + _, producers_original = get_tensor_consumers_and_producers(float_model) + weight_dq_node_original = producers_original.get(conv_node_original.input[1], None) + initializers_original = {initializer.name: initializer for initializer in float_model.graph.initializer} + scale_name_original = weight_dq_node_original.input[1] + scale_val_original = onnx.numpy_helper.to_array(initializers_original[scale_name_original]) + zp_name_original = weight_dq_node_original.input[2] + zp_val_original = onnx.numpy_helper.to_array(initializers_original[zp_name_original]) + + input_data_list = [ + {"input_0": rng.uniform(-10.0, 10.0, inp_shape).astype(np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + op_types_to_quantize=["Conv"], + ) + + # The final model should have everything quantized + qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + # Check that the pre-quantized weight still has the same scale/zp after quantization + qdq_model = onnx.load_model(qdq_model_path) + conv_node = next((node for node in qdq_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node, None) + + _, producers = get_tensor_consumers_and_producers(qdq_model) + weight_dq_node = producers.get(conv_node.input[1], None) + initializers = {initializer.name: initializer for initializer in qdq_model.graph.initializer} + + scale_name = weight_dq_node.input[1] + self.assertEqual(scale_name, scale_name_original) + scale_val = onnx.numpy_helper.to_array(initializers[scale_name]) + self.assertEqual(scale_val, scale_val_original) + + zp_name = weight_dq_node.input[2] + self.assertEqual(zp_name, zp_name_original) + zp_val = 
onnx.numpy_helper.to_array(initializers[zp_name]) + self.assertEqual(zp_val, zp_val_original) + + def test_quantize_with_prequantized_input(self): + """ + Test quantization of Conv with pre-quantized input and dynamic weight. + """ + rng = np.random.default_rng(123) + test_configs = [ + (onnx.TensorProto.FLOAT, False), + (onnx.TensorProto.FLOAT16, False), + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT16, True), + ] + + for float_type, convert_weight_qtype in test_configs: + with self.subTest(float_type=float_type): + convert_label = "_convert_qtype" if convert_weight_qtype else "" + label = f"_{onnx.TensorProto.DataType.Name(float_type)}{convert_label}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv.f32.prequant_input{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"conv.prequant_input{label}.qdq.onnx") + + inp_shape = [1, 2, 40, 40] + weight_shape = [2, 2, 20, 20] + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type) + + # range = 3.0, scale = 3/255, zp = 127 + input_scale_data = np.array(3 / 255, dtype=np_dtype) + input_zp_data = np.array(127, dtype=np.uint8) + input_data = np.linspace(-1.5, 1.5, num=3200, dtype=np_dtype).reshape(inp_shape) + input_quant_data = quantize_nparray(onnx.TensorProto.UINT8, input_data, input_scale_data, input_zp_data) + + bias_data = np.array([-10.0, 10.0], dtype=np_dtype) + float_model = self.build_conv_dynamic_weight_model( + input_quant_data, input_scale_data, input_zp_data, weight_shape, bias_data, float_type + ) + + onnx.checker.check_model(float_model, True) + onnx.save_model(float_model, float_model_path) + + # Check that the input model only has a pre-quantized input and save its scale/zero-point + # to check that it doesn't change after quantization. + float_node_counts = {"QuantizeLinear": 0, "DequantizeLinear": 1} + check_op_type_count(self, float_model_path, **float_node_counts) + conv_node_original = next((node for node in float_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node_original, None) + + _, producers_original = get_tensor_consumers_and_producers(float_model) + input_dq_node_original = producers_original.get(conv_node_original.input[0], None) + initializers_original = {initializer.name: initializer for initializer in float_model.graph.initializer} + scale_name_original = input_dq_node_original.input[1] + scale_val_original = onnx.numpy_helper.to_array(initializers_original[scale_name_original]) + zp_name_original = input_dq_node_original.input[2] + zp_val_original = onnx.numpy_helper.to_array(initializers_original[zp_name_original]) + + # Create data reader with random input calibration data. 
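+ # 'dyn_weight' is the model's only graph input here, so calibration data is supplied for it.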
+ dyn_weight_data_list = [ + {"dyn_weight": rng.uniform(-10.0, 10.0, weight_shape).astype(np_dtype)}, + ] + data_reader = TestDataFeeds(dyn_weight_data_list) + + extra_options = {} + if convert_weight_qtype: + # Test converting the dynamic weight's quantization type, which results in + # dyn_weight -> Q(u16) -> DQ(f32) -> Q(u8) -> DQ(f32) -> Conv + extra_options["TensorQuantOverrides"] = { + "dyn_weight": [{"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8}}], + } + + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + op_types_to_quantize=["Conv"], + extra_options=extra_options, + ) + + # The final model should have everything quantized + qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4} + if convert_weight_qtype: + qdq_node_counts["QuantizeLinear"] += 1 + qdq_node_counts["DequantizeLinear"] += 1 + + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + # Check that the pre-quantized input still has the same scale/zp after quantization + qdq_model = onnx.load_model(qdq_model_path) + conv_node = next((node for node in qdq_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node, None) + + _, producers = get_tensor_consumers_and_producers(qdq_model) + input_dq_node = producers.get(conv_node.input[0], None) + initializers = {initializer.name: initializer for initializer in qdq_model.graph.initializer} + + scale_name = input_dq_node.input[1] + self.assertEqual(scale_name, scale_name_original) + scale_val = onnx.numpy_helper.to_array(initializers[scale_name]) + self.assertEqual(scale_val, scale_val_original) + + zp_name = input_dq_node.input[2] + self.assertEqual(zp_name, zp_name_original) + zp_val = onnx.numpy_helper.to_array(initializers[zp_name]) + self.assertEqual(zp_val, zp_val_original) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quant_util.py b/onnxruntime/test/python/quantization/test_quant_util.py index 96d841654adbd..b23d53f2a04e8 100644 --- a/onnxruntime/test/python/quantization/test_quant_util.py +++ b/onnxruntime/test/python/quantization/test_quant_util.py @@ -145,7 +145,7 @@ def test_quantize_data_4bit(self): for onnx_type, symmetric in subtest_configs: with self.subTest(onnx_type=onnx_type, symmetric=symmetric): - _, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) + zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) is_signed = onnx_type == onnx.TensorProto.INT4 np_int_type = numpy.int8 if is_signed else numpy.uint8 qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type) diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 21a772c5f56c7..41dae04f1c6ff 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -36,7 +36,7 @@ def setUp(self): self.bias = np.array([0.0, 1.0], dtype=np.float32) self.default_act_qtype = onnx.TensorProto.UINT8 self.default_wgt_qtype = onnx.TensorProto.UINT8 - self.default_wgt_qtype_per_channel = onnx.TensorProto.INT8 + self.default_wgt_qtype_per_channel = onnx.TensorProto.UINT8 self.default_bias_qtype = onnx.TensorProto.INT32 self.default_zp_scales = { @@ -49,7 +49,8 @@ def setUp(self): 
self.default_zp_scales_per_channel = { "INP": (0, np.float32(0.0235294122248888)), "SIG_OUT": (0, np.float32(0.003911871928721666)), - "WGT": ([0, 0], [np.float32(0.015748031437397003), np.float32(0.011811023578047752)]), + # per-channel weights are always symmetric (ie. zp = (qmin + qmax) / 2) + "WGT": ([127, 127], [np.float32(0.015748031437397003), np.float32(0.011811023578047752)]), "BIAS": ([0, 0], [np.float32(0.00006160428165458143), np.float32(0.00004620321124093607)]), "OUT": (0, np.float32(0.005075461231172085)), } @@ -420,12 +421,17 @@ def test_qdq_overrides_per_channel2(self): self.assertEqual(wgt_zp.data_type, quant_type.tensor_type) for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data)): - wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=reduce_range) + wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType( + wgt_zp.data_type, + symmetric=True, # per-channel is always symmetric + reduce_range=reduce_range, + ) expected_zp, expected_scale = compute_scale_zp( np.array(rmin_vals[index], dtype=np.float32), np.array(rmax_vals[index], dtype=np.float32), wgt_qmin, wgt_qmax, + symmetric=True, # per-channel is always symmetric ) self.assertEqual(zp, expected_zp) self.assertEqual(scale, np.float32(expected_scale)) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index d8acb66158ed2..d922f153b4b91 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -72,6 +72,7 @@ class SdpaKernel(IntEnum): TRT_FLASH_ATTENTION = 32 TRT_CROSS_ATTENTION = 64 TRT_CAUSAL_ATTENTION = 128 + LEAN_ATTENTION = 256 # Since we support attention bias, so we only need support up to 2D mask. @@ -598,8 +599,8 @@ def measure_latency(cuda_session: CudaSession, input_dict): return end - start -def flops(batch, sequence_length, head_size, num_heads, causal): - return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1) +def flops(batch, sequence_length_q, sequence_length_kv, head_size, num_heads, causal): + return 4 * batch * sequence_length_q * sequence_length_kv * num_heads * head_size // (2 if causal else 1) def tflops_per_second(flop, time): @@ -613,6 +614,7 @@ def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str: kernel_names = { SdpaKernel.DEFAULT: "ort:default", SdpaKernel.FLASH_ATTENTION: "ort:flash", + SdpaKernel.LEAN_ATTENTION: "ort:lean", SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", SdpaKernel.MATH: "ort:math", @@ -808,16 +810,17 @@ def sdpa_kernel_from_debug_info( ): os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1" captured_text = None + try: with CaptureStdout() as captured: session = create_session(config, sess_options, attention_kernel=attention_kernel) input_dict = config.random_inputs() session.infer(input_dict) - captured_text = captured.output.decode() + captured_text = captured.output.decode() except Exception as e: print(f"Failed to run {attention_kernel=} for {config=}. 
Exception: {e}") - finally: - os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0" + + os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0" if captured_text is not None: m = re.search("SdpaKernel=(?P[A-Z_]+)", captured_text) @@ -825,6 +828,7 @@ def sdpa_kernel_from_debug_info( name = m.group("kernel") kernel_names = { "FLASH_ATTENTION": "ort:flash", + "LEAN_ATTENTION": "ort:lean", "EFFICIENT_ATTENTION": "ort:efficient", "CUDNN_FLASH_ATTENTION": "ort:cudnn", "MATH": "ort:math", @@ -867,6 +871,15 @@ def run_tflops_test( SdpaKernel.CUDNN_FLASH_ATTENTION, SdpaKernel.MATH, ] + + if args.past_sequence_length > 0: + backends.append(SdpaKernel.LEAN_ATTENTION) + + if args.past_sequence_length > 0 and causal: + backends.remove(SdpaKernel.CUDNN_FLASH_ATTENTION) + + if args.past_sequence_length > 4096: + backends.remove(SdpaKernel.MATH) else: backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: @@ -884,6 +897,8 @@ def run_tflops_test( for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + if past_sequence_length > 0 and input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + continue config = MultiHeadAttentionConfig( batch_size=batch_size, sequence_length=sequence_length, @@ -900,6 +915,7 @@ def run_tflops_test( dtype=torch.float16 if use_gpu else torch.float, share_past_present_buffer=False, input_format=input_format, + has_past_input=past_sequence_length > 0, has_attn_bias=args.has_attn_bias, broadcast_attn_bias_dim_0=args.broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1=args.broadcast_attn_bias_dim_1, @@ -926,11 +942,19 @@ def run_tflops_test( print(f"skip input_format for {vars(config)}") continue + if use_gpu and config.total_sequence_length > 8192: + if config.verbose: + print(f"skip large sequence length for {vars(config)}") + continue + if use_gpu: actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options) if actual_kernel is None: print(f"Warning: skip {config} since kernel from debug info is None") continue + if actual_kernel != request_kernel and request_kernel != "ort:default": + print(f"Skip since {actual_kernel=} != {request_kernel=}") + continue else: # CPU has no debug info for now. 
actual_kernel = request_kernel @@ -956,11 +980,17 @@ def run_tflops_test( format_str = InputFormats.input_format_str(input_format) # compute TFLOPS per second - speed = None - if past_sequence_length == 0: - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) + speed = tflops_per_second( + flops( + batch_size, + sequence_length, + sequence_length + past_sequence_length, + head_size, + num_heads, + causal, + ), + average_latency, + ) row = { "use_gpu": use_gpu, @@ -983,11 +1013,11 @@ def run_tflops_test( } csv_writer.writerow(row) - speed = f"{speed:.2f}" if speed is not None else "NA" + speed = f"{speed:.3f}" if speed is not None else "NA" print( f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t" f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" - f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}" + f"{intra_op_num_threads}\t{average_latency * 1000:.3f}\t{speed}\t{actual_kernel}\t{request_kernel}" ) @@ -1055,7 +1085,17 @@ def run_torch_test( except RuntimeError: continue - speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) + speed = tflops_per_second( + flops( + batch_size, + sequence_length, + sequence_length + past_sequence_length, + head_size, + num_heads, + causal, + ), + torch_latency, + ) input_format = "Q,K,V" print( f"{input_format}\t{causal}\t{False}\t{batch_size}\t" @@ -1090,7 +1130,8 @@ def run_tflops_tests(args): features += "_causal" if args.past_sequence_length > 0: features += "_past" - csv_filename = "benchmark_mha_{}_{}_{}.csv".format( + csv_filename = "{}_{}_{}_{}.csv".format( + args.csv_filename_prefix, features, "torch" if args.torch else "ort", datetime.now().strftime("%Y%m%d-%H%M%S"), @@ -1343,6 +1384,14 @@ def _parse_arguments(): ) parser.set_defaults(broadcast_attn_bias_dim_1=False) + parser.add_argument( + "--csv_filename_prefix", + required=False, + type=str, + default="benchmark_mha", + help="Prefix of csv filename", + ) + args = parser.parse_args() return args diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index ff6dd16e698df..8d811219d4dac 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -5,45 +5,104 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" +# Usage: benchmark_mha.sh [gpu|cpu|lean] +task="${1:-gpu}" -export CUDA_VISIBLE_DEVICES=0 -python benchmark_mha.py --use_gpu +# Function to lock GPU clocks and set power limit for a GPU +configure_gpu() { + local gpu_id=$1 -echo "Benchmark BERT-Large performance on GPU without attention bias" -python benchmark_mha.py --use_gpu -b 16 + # Ensure nvidia-smi is available + if ! command -v nvidia-smi &> /dev/null + then + echo "nvidia-smi not found. Please ensure NVIDIA drivers are installed." 
+ exit + fi -echo "Benchmark BERT-Large performance on GPU with attention bias" -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + # Enable Persistence Mode + sudo nvidia-smi -pm 1 -i $gpu_id -python benchmark_mha.py --use_gpu --use_cuda_graph -python benchmark_mha.py --use_gpu --torch + # Get the maximum clock speeds for graphics and memory. + nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A3 "Max Clocks" + max_graphics_clock=$(nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A1 "Max Clocks" | grep "Graphics" | awk '{print $3}') + max_memory_clock=$(nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A3 "Max Clocks" | grep "Memory" | awk '{print $3}') -cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + # Lock the GPU clocks to maximum frequencies + sudo nvidia-smi -i $gpu_id --lock-gpu-clocks=$max_graphics_clock,$max_graphics_clock + sudo nvidia-smi -i $gpu_id --lock-memory-clocks=$max_memory_clock,$max_memory_clock -echo "Benchmark performance on CPU with number of threads:" -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + nvidia-smi --query-gpu=clocks.gr,clocks.sm,clocks.mem --format=csv + echo "GPU $gpu_id clocks locked to $max_graphics_clock MHz (graphics) and $max_memory_clock MHz (memory)" -python benchmark_mha.py --intra_op_num_threads 1 -python benchmark_mha.py --intra_op_num_threads 2 -python benchmark_mha.py --intra_op_num_threads 4 -python benchmark_mha.py --intra_op_num_threads 8 + # Set Power Limit to maximum + power_limit=$(nvidia-smi --query-gpu=power.limit -i 0 --format=csv | grep "0" | awk '{print $1}') + power_limit=${power_limit%.*} + sudo nvidia-smi -pl $power_limit -i $gpu_id + export CUDA_VISIBLE_DEVICES=$gpu_id +} -echo "Benchmark performance on CPU with default threads settings:" -python benchmark_mha.py -ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py -python benchmark_mha.py --torch +run_gpu_benchmarks() { + echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" -python benchmark_mha.py --causal -python benchmark_mha.py --torch --causal + python benchmark_mha.py --use_gpu -# Pytorch SDPA does not support causal attention with past state, we only test ORT here. 
-python benchmark_mha.py --causal --has_past + echo "Benchmark BERT-Large performance on GPU without attention bias" + python benchmark_mha.py --use_gpu -b 16 -cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv + echo "Benchmark BERT-Large performance on GPU with attention bias" + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + + python benchmark_mha.py --use_gpu --use_cuda_graph + python benchmark_mha.py --use_gpu --torch + + cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv +} + +run_lean_benchmarks() { + echo "Benchmark long context decoding performance on GPU" + for b in 1 4 16; do + for s in 32 64 128 256 512 1024 2048 4096 8192 16384 32768 65536; do + python benchmark_mha.py --use_gpu --causal -b $b -s 1 -p $s -n 16 -d 64 -r 1000 --csv_filename_prefix benchmark_lean + python benchmark_mha.py --use_gpu --causal -b $b -s 1 -p $s -n 32 -d 128 -r 1000 --csv_filename_prefix benchmark_lean + done + done + cat benchmark_lean_*.csv > lean_benchmark_results.csv +} + +run_cpu_benchmarks() { + echo "Benchmark performance on CPU with number of threads:" + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + + python benchmark_mha.py --intra_op_num_threads 1 + python benchmark_mha.py --intra_op_num_threads 2 + python benchmark_mha.py --intra_op_num_threads 4 + python benchmark_mha.py --intra_op_num_threads 8 + + + echo "Benchmark performance on CPU with default threads settings:" + python benchmark_mha.py + ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py + python benchmark_mha.py --torch + + python benchmark_mha.py --causal + python benchmark_mha.py --torch --causal + + # Pytorch SDPA does not support causal attention with past state, we only test ORT here. 
+ python benchmark_mha.py --causal --has_past + + cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv +} + +[ "$task" != "cpu" ] && configure_gpu 0 + +[ "$task" == "gpu" ] && run_gpu_benchmarks + +[ "$task" == "cpu" ] && run_cpu_benchmarks + +[ "$task" == "lean" ] && run_lean_benchmarks diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 46ab905977f48..a74d5389e9047 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -24,7 +24,7 @@ from parameterized import parameterized from test_gqa_cpu import smooth_softmax_ref -from onnxruntime import InferenceSession, OrtValue, SessionOptions +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers torch.manual_seed(0) @@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff( def has_flash_attention(): if not torch.cuda.is_available(): return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False major, _ = torch.cuda.get_device_capability() return major >= 8 and ( platform.system() == "Linux" @@ -2009,6 +2011,8 @@ def has_flash_attention(): def has_memory_efficient(): if not torch.cuda.is_available(): return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False major, minor = torch.cuda.get_device_capability() if major < 5 or (major == 5 and minor < 3): return False @@ -2047,8 +2051,8 @@ def mha_test_cases(): (2048, 2048), ] ) - num_h = [1, 3] if pipeline_mode else [1, 6, 16] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [3] if pipeline_mode else [1, 6, 16] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] for b in batches: for s, s2 in seqs: @@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases(): batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ - (127, 127), - (35, 35), (2000, 2000), - (200, 200), - (240, 240), ] if pipeline_mode else [ @@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases(): (240, 240), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] torch.manual_seed(69) for b in batches: @@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases(): batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), (240, 240), ] if pipeline_mode @@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases(): (240, 240), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] torch.manual_seed(69) for b in batches: @@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases(): def gqa_past_memory_efficient_test_cases(): batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 1024)] if pipeline_mode else [ (1, 128), @@ -2179,8 +2175,8 @@ def 
gqa_past_memory_efficient_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases(): def gqa_past_flash_attention_test_cases(): batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 2048)] if pipeline_mode else [ (1, 128), @@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases(): def gqa_interactive_one_batch_flash_attention_test_cases(): batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(128, 2048)] if pipeline_mode else [ (1, 128), @@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(32, 128)] if pipeline_mode else [ (1, 128), @@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2326,120 +2322,114 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): ) -class TestGQA(unittest.TestCase): - @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) - def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): - if not has_memory_efficient(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") +@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.") +class TestFlashGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_flash_attention_test_cases()) + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): + print("------- FLASH ATTENTION (PROMPT CASE) --------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" 
parity_check_gqa_prompt( config, - rtol=5e-3, - atol=5e-3, + local=local, past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=False, + use_smooth_softmax=True, ) parity_check_gqa_prompt_no_buff( config, - rtol=5e-3, - atol=5e-3, + local=local, past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=True, + use_smooth_softmax=False, ) - @parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - if not has_flash_attention(): - return - print("------- FLASH ATTENTION (PROMPT CASE) --------") + @parameterized.expand(gqa_past_flash_attention_test_cases()) + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): + print("------- FLASH ATTENTION (TOKEN GEN) -------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( + parity_check_gqa_past( config, local=local, past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=True, + use_smooth_softmax=False, ) - parity_check_gqa_prompt_no_buff( + parity_check_gqa_past_no_buff( config, local=local, past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=False, + use_smooth_softmax=True, ) - @parameterized.expand(gqa_past_memory_efficient_test_cases()) - def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): - if not has_memory_efficient(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") + @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) + def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + print("------- FLASH ATTENTION (INTERACTIVE) -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_past( config, + local=local, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, + rtol=5e-3, + atol=5e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - softcap=softcap, - use_smooth_softmax=True, ) parity_check_gqa_past_no_buff( config, + local=local, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, + rtol=5e-3, + atol=5e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - softcap=softcap, - use_smooth_softmax=False, ) - @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - if not has_flash_attention(): - return - print("------- FLASH ATTENTION (TOKEN GEN) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( +@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.") +class TestMemoryEfficientGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) + def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") + + parity_check_gqa_prompt( config, - local=local, + rtol=5e-3, 
+ atol=5e-3, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, use_smooth_softmax=False, ) - parity_check_gqa_past_no_buff( + parity_check_gqa_prompt_no_buff( config, - local=local, + rtol=5e-3, + atol=5e-3, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2447,38 +2437,36 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle use_smooth_softmax=True, ) - @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) - def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): - if not has_flash_attention(): - return - print("------- FLASH ATTENTION (INTERACTIVE) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + @parameterized.expand(gqa_past_memory_efficient_test_cases()) + def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") parity_check_gqa_past( config, - local=local, past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, + use_smooth_softmax=True, ) parity_check_gqa_past_no_buff( config, - local=local, past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, + use_smooth_softmax=False, ) @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): - if not has_memory_efficient(): - return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py index 99460722c2469..a5910c28c2975 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -16,16 +16,16 @@ import onnxruntime -class TestGQA(unittest.TestCase): +@unittest.skipIf( + (not torch.cuda.is_available()) + or (platform.system() != "Linux") + or ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()), + reason="ROCm is not available, skipping tests.", +) +class TestRocmGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_flash_attention_test_cases()) def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" - if not torch.cuda.is_available(): - return - if platform.system() != "Linux": - return - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - return print("------- FLASH ATTENTION (PROMPT CASE) --------") parity_check_gqa_prompt( @@ -52,12 +52,6 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte @parameterized.expand(gqa_past_flash_attention_test_cases()) def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" - if not torch.cuda.is_available(): - return - if platform.system() != "Linux": - return - if "CUDAExecutionProvider" in 
onnxruntime.get_available_providers(): - return print("------- FLASH ATTENTION (TOKEN GEN) -------") parity_check_gqa_past( diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 08ec5de328b9d..77b4b326bf645 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -1900,7 +1900,7 @@ class TestGQA(unittest.TestCase): def test_gqa_no_past(self): torch.manual_seed(69) print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") - batches = [1, 3] if pipeline_mode else [1, 3, 5] + batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ (127, 127), @@ -1916,8 +1916,8 @@ def test_gqa_no_past(self): (8000, 8000), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] for b in batches: for sq, skv in seqs: for n, n2 in num_h: @@ -1954,9 +1954,9 @@ def test_gqa_no_past(self): def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") - batches = [1, 3] if pipeline_mode else [1, 3, 5] + batches = [1] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 128)] if pipeline_mode else [ (1, 128), @@ -1972,8 +1972,8 @@ def test_gqa_past(self): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: for s, s2 in seqs: @@ -2018,7 +2018,7 @@ def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(256, 2048)] if pipeline_mode else [ (1, 128), @@ -2034,8 +2034,8 @@ def test_gqa_interactive_one_batch(self): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: for s, s2 in seqs: diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 69f0035ef8a17..9e7c7378370c1 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -9,6 +9,7 @@ import concurrent.futures import itertools +import os import unittest from typing import Dict, List, Optional @@ -400,6 +401,49 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): yield config +def lean_attention_test_cases(provider: str, comprehensive: bool): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 80: + return + yield + + batch_sizes = [1, 2, 3] if comprehensive else [1, 2] + sequence_lengths = [2, 15, 16, 255, 256, 512, 1024, 2048, 4096, 8192] if comprehensive else [2, 255, 512] + heads = [1, 4, 16] if comprehensive else [1, 4] + head_sizes = [64, 128] + 
device, dtype, formats = get_provider_support_info(provider, True) + mask_formats = [AttentionMaskFormat.Mask_None] + + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory + for batch_size in batch_sizes: + for total_seq_len in sequence_lengths: + for num_heads in heads: + for head_size in head_sizes: + for format in formats: + for causal in get_causal_support(format): + for is_prompt in [False]: + for mask_format in mask_formats: + sequence_length = total_seq_len if is_prompt else 1 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=total_seq_len - sequence_length, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=True, + share_past_present_buffer=False, + input_format=format, + mask_format=mask_format, + ) + yield config + + def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return @@ -787,6 +831,12 @@ def run_mha_cuda(self): for config in mha_test_cases("CUDAExecutionProvider", comprehensive_mode): parity_check_mha(config, rtol=5e-3, atol=5e-3) + def run_lean_attention(self): + os.environ["ORT_ENABLE_LEAN_ATTENTION"] = "1" + for config in lean_attention_test_cases("CUDAExecutionProvider", comprehensive_mode): + parity_check_mha(config, rtol=5e-3, atol=5e-3 if config.total_sequence_length <= 512 else 5e-2) + os.environ.pop("ORT_ENABLE_LEAN_ATTENTION", None) + def run_mha_cpu(self): for config in mha_test_cases("CPUExecutionProvider", comprehensive_mode): parity_check_mha(config, rtol=5e-3, atol=5e-3) @@ -842,6 +892,7 @@ def test_all(self): # Run tests sequentially to avoid out of memory issue. self.run_mha_cpu() self.run_mha_cuda() + self.run_lean_attention() self.run_mha_cuda_multi_threading_default() self.run_mha_cuda_multi_threading_cudnn() self.run_mha_cuda_multi_threading_efficient() diff --git a/onnxruntime/test/python/transformers/test_optimizer.py b/onnxruntime/test/python/transformers/test_optimizer.py index c7db636a2f11f..058b1d2c9e0fa 100644 --- a/onnxruntime/test/python/transformers/test_optimizer.py +++ b/onnxruntime/test/python/transformers/test_optimizer.py @@ -5,30 +5,21 @@ # license information. 
# -------------------------------------------------------------------------- -# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG +# For live logging, use the following command: +# pytest -o log_cli=true --log-cli-level=DEBUG test_optimizer.py -import shutil import unittest -import pytest -import torch from model_loader import get_fusion_test_model, get_test_data_path from onnx import TensorProto, load_model from parity_utilities import find_transformers_source -from transformers import is_tf_available if find_transformers_source(): - from benchmark_helper import ConfigModifier, OptimizerInfo, Precision from fusion_options import FusionOptions - from huggingface_models import MODELS - from onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf from onnx_model import OnnxModel from optimizer import optimize_model else: - from onnxruntime.transformers.benchmark_helper import ConfigModifier, OptimizerInfo, Precision from onnxruntime.transformers.fusion_options import FusionOptions - from onnxruntime.transformers.huggingface_models import MODELS - from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf from onnxruntime.transformers.onnx_model import OnnxModel from onnxruntime.transformers.optimizer import optimize_model @@ -66,70 +57,6 @@ def verify_node_count(self, onnx_model, expected_node_count, test_name): self.assertEqual(len(onnx_model.get_nodes_by_op_type(op_type)), count) - # test huggingface pytorch model - def _test_optimizer_on_huggingface_model( - self, - model_name, - expected_fusion_result_list, - inputs_count=1, - validate_model=True, - ): - # Remove cached model so that CI machine has enough space. Do not remove cache models in dev machine. - if not find_transformers_source(): - shutil.rmtree("./cache_models", ignore_errors=True) - shutil.rmtree("./onnx_models", ignore_errors=True) - - # expect fusion result list have the following keys - # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization - model_fusion_statistics = {} - - input_names = MODELS[model_name][0] - - config_modifier = ConfigModifier(None) - fusion_options = None - model_class = "AutoModel" - with torch.no_grad(): - _, is_valid_onnx_model, _, _ = export_onnx_model_from_pt( - model_name, - MODELS[model_name][1], # opset version - MODELS[model_name][2], # use_external_data_format - MODELS[model_name][3], # optimization model type - model_class, - config_modifier, - "./cache_models", - "./onnx_models", - input_names[:inputs_count], - False, - Precision.FLOAT32, - OptimizerInfo.BYSCRIPT, - True, - True, - True, - model_fusion_statistics, - fusion_options, - ) - - if validate_model: - self.assertEqual(is_valid_onnx_model, True) - - expected_node_count = { - "EmbedLayerNormalization": expected_fusion_result_list[0], - "Attention": expected_fusion_result_list[1], - "Gelu": expected_fusion_result_list[2], - "FastGelu": expected_fusion_result_list[3], - "BiasGelu": expected_fusion_result_list[4], - "LayerNormalization": expected_fusion_result_list[5], - "SkipLayerNormalization": expected_fusion_result_list[6], - } - - for value in model_fusion_statistics.values(): - actual_node_count = value - - for op_type, count in expected_node_count.items(): - if op_type not in actual_node_count or actual_node_count[op_type] != count: - print(f"expected: {expected_node_count} got {actual_node_count}") - self.assertTrue(False) - def test_gpt2_past(self): for enable_skip_layer_norm_fusion 
in [False, True]: input_path = _get_test_model_path("gpt2_past") @@ -227,176 +154,6 @@ def test_embed_layer_norm_fusion(self): } self.verify_node_count(model, expected_node_count, file) - @pytest.mark.slow - def test_huggingface_bert_fusion_1(self): - self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1) - - @pytest.mark.slow - def test_huggingface_bert_fusion_2(self): - self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2) - - @pytest.mark.slow - def test_huggingface_bert_fusion_3(self): - self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=3) - - @pytest.mark.slow - def test_huggingface_openaigpt_fusion(self): - self._test_optimizer_on_huggingface_model("openai-gpt", [0, 12, 0, 12, 0, 0, 24]) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of gpt-2 on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_gpt2_fusion(self): - self._test_optimizer_on_huggingface_model("gpt2", [0, 12, 0, 12, 0, 25, 0]) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of xlm on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_xlm_fusion(self): - self._test_optimizer_on_huggingface_model("xlm-mlm-ende-1024", [0, 6, 0, 0, 6, 0, 13]) - - @pytest.mark.slow - def test_huggingface_roberta_fusion(self): - self._test_optimizer_on_huggingface_model("roberta-base", [0, 12, 0, 0, 12, 1, 24]) - - @pytest.mark.slow - def test_huggingface_distillbert_fusion(self): - self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=1) - self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=2) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of camembert on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_camembert_fusion(self): - self._test_optimizer_on_huggingface_model("camembert-base", [0, 12, 0, 0, 12, 1, 24], validate_model=False) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of albert on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_albert_fusion(self): - self._test_optimizer_on_huggingface_model("albert-base-v1", [0, 12, 0, 0, 12, 1, 24]) - - @pytest.mark.slow - @unittest.skip("skip fusion test of t5 since it is not implemented yet") - def test_huggingface_t5_fusion(self): - self._test_optimizer_on_huggingface_model("t5-small", [0, 0, 0, 0, 0, 0, 0]) - - @pytest.mark.slow - def test_huggingface_xlmroberta_fusion(self): - self._test_optimizer_on_huggingface_model("xlm-roberta-base", [0, 12, 0, 0, 12, 1, 24]) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of flaubert on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_flaubert_fusion(self): - self._test_optimizer_on_huggingface_model( - "flaubert/flaubert_base_cased", - [0, 12, 0, 0, 12, 0, 25], - validate_model=False, - ) - self._test_optimizer_on_huggingface_model( - "flaubert/flaubert_small_cased", - [0, 6, 0, 0, 6, 12, 1], - validate_model=False, - ) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of dialogpt on PyTorch 1.12 and transformers 4.18. 
TODO: fix it") - def test_huggingface_dialogpt_fusion(self): - self._test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0]) - - @pytest.mark.slow - def test_huggingface_bart_fusion(self): - self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30]) - - @pytest.mark.slow - def test_huggingface_vit_fusion(self): - self._test_optimizer_on_huggingface_model("google/vit-base-patch16-224", [0, 11, 0, 0, 12, 1, 24]) - - -@unittest.skipUnless(is_tf_available(), "skip TestBertOptimizationTF since tensorflow is not available") -class TestTensorflowModelOptimization(unittest.TestCase): - def setUp(self): - try: - import tf2onnx # noqa: F401 - except ImportError: - self.skipTest("skip TestBertOptimizationTF since tf2onnx not installed") - - def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, inputs_count, validate_model=True): - # Remove cached model so that CI machine has enough space. Do not remove cache models in dev machine. - if not find_transformers_source(): - shutil.rmtree("./cache_models", ignore_errors=True) - shutil.rmtree("./onnx_models", ignore_errors=True) - - # expect fusion result list have the following keys - # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization - model_fusion_statistics = {} - print("testing mode ", model_name) - print("testing input number = ", inputs_count) - input_names = MODELS[model_name][0] - - config_modifier = ConfigModifier(None) - fusion_options = None - model_class = "AutoModel" - with torch.no_grad(): - _, is_valid_onnx_model, _, _ = export_onnx_model_from_tf( - model_name, - MODELS[model_name][1], # opset version - MODELS[model_name][2], # use_external_data_format - MODELS[model_name][3], # optimization model - model_class, - config_modifier, - "./cache_models", - "./onnx_models", - input_names[:inputs_count], - False, - Precision.FLOAT32, - True, - True, - True, - True, - model_fusion_statistics, - fusion_options, - ) - - onnx_model = next(iter(model_fusion_statistics.keys())) - fusion_result_list = list(model_fusion_statistics[onnx_model].values()) - - if validate_model: - self.assertEqual(is_valid_onnx_model, True) - self.assertEqual(fusion_result_list, expected_fusion_result_list) - - @pytest.mark.slow - def test_huggingface_bert_base_cased_from_tf2onnx_1(self): - self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1) - - @pytest.mark.slow - def test_huggingface_bert_base_cased_from_tf2onnx_2(self): - self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2) - - @pytest.mark.slow - def test_huggingface_bert_base_cased_from_tf2onnx_3(self): - self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3) - - @pytest.mark.slow - def test_huggingface_distilgpt2_from_tf2onnx(self): - self._test_optimizer_on_tf_model("distilgpt2", [0, 0, 0, 0, 0, 12, 1], 1) - - @pytest.mark.slow - def test_huggingface_albert_from_tf2onnx(self): - self._test_optimizer_on_tf_model("albert-base-v1", [0, 0, 0, 0, 0, 0, 25], 1) - - @pytest.mark.slow - def test_huggingface_gpt2_from_tf2onnx(self): - self._test_optimizer_on_tf_model("gpt2", [0, 0, 0, 0, 0, 24, 1], 1, validate_model=False) - - @pytest.mark.slow - def test_huggingface_roberta_from_tf2onnx(self): - self._test_optimizer_on_tf_model("roberta-base", [0, 12, 0, 0, 0, 0, 25], 1, validate_model=False) - - @pytest.mark.slow - def test_huggingface_distilbert_from_tf2onnx(self): - 
self._test_optimizer_on_tf_model("distilbert-base-uncased", [0, 0, 0, 0, 0, 0, 13], 1, validate_model=False) - - @pytest.mark.slow - def test_huggingface_xlm_from_tf2onnx(self): - self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", [0, 0, 0, 0, 0, 1, 12], 1, validate_model=False) - if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_optimizer_huggingface_bert.py b/onnxruntime/test/python/transformers/test_optimizer_huggingface_bert.py new file mode 100644 index 0000000000000..e4f883dc8b45c --- /dev/null +++ b/onnxruntime/test/python/transformers/test_optimizer_huggingface_bert.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +# For live logging, use the following command: +# pytest -o log_cli=true --log-cli-level=DEBUG test_optimizer_huggingface_bert.py + +import shutil +import unittest +from pathlib import Path + +import torch +from parity_utilities import find_transformers_source +from transformers.utils import default_cache_path + +if find_transformers_source(): + from benchmark_helper import ConfigModifier, OptimizerInfo, Precision + from compare_bert_results import run_test as bert_parity_test + from onnx_exporter import export_onnx_model_from_pt +else: + from onnxruntime.transformers.benchmark_helper import ConfigModifier, OptimizerInfo, Precision + from onnxruntime.transformers.compare_bert_results import run_test as bert_parity_test + from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt + + +class TestHuggingfaceBertModelOptimization(unittest.TestCase): + def run_optimizer_on_model( + self, + model_name, + expected_fusion_result_list, + inputs_count=1, + validate_model=True, + opset_version=16, + use_external_data_format=False, + model_type="bert", + ): + onnx_dir = Path(".") / "onnx_models" / model_name + shutil.rmtree(onnx_dir, ignore_errors=True) + + Path(onnx_dir).mkdir(parents=True, exist_ok=True) + + model_fusion_statistics = {} + + input_names = ["input_ids", "attention_mask", "token_type_ids"] + + config_modifier = ConfigModifier(None) + fusion_options = None + model_class = "AutoModel" + with torch.no_grad(): + optimized_model_path, is_valid_onnx_model, _, _ = export_onnx_model_from_pt( + model_name=model_name, + opset_version=opset_version, + use_external_data_format=use_external_data_format, + model_type=model_type, + model_class=model_class, + config_modifier=config_modifier, + cache_dir=default_cache_path, + onnx_dir=str(onnx_dir), + input_names=input_names[:inputs_count], + use_gpu=False, + precision=Precision.FLOAT32, + optimizer_info=OptimizerInfo.BYSCRIPT, + validate_onnx=True, + use_raw_attention_mask=True, + overwrite=True, + model_fusion_statistics=model_fusion_statistics, + fusion_options=fusion_options, + ) + + if validate_model: + self.assertEqual(is_valid_onnx_model, True) + + expected_node_count = { + "EmbedLayerNormalization": expected_fusion_result_list[0], + "Attention": expected_fusion_result_list[1], + "Gelu": expected_fusion_result_list[2], + "FastGelu": expected_fusion_result_list[3], + "BiasGelu": expected_fusion_result_list[4], + "LayerNormalization": expected_fusion_result_list[5], + "SkipLayerNormalization": expected_fusion_result_list[6], + } + + node_count = None + 
for value in model_fusion_statistics.values(): + node_count = value + self.assertIsNotNone(node_count) + + actual_node_count = {} + for op_type in expected_node_count: + actual_node_count[op_type] = node_count.get(op_type, 0) + + expected = ", ".join(f"{key}: {value}" for key, value in sorted(expected_node_count.items())) + actual = ", ".join(f"{key}: {value}" for key, value in sorted(actual_node_count.items())) + self.assertEqual(expected, actual) + + suffix = "_fp32_cpu.onnx" + assert optimized_model_path.endswith(suffix) + baseline_model_path = optimized_model_path[: -len(suffix)] + ".onnx" + for batch_size in [1, 2]: + for sequence_length in [1, 8]: + max_abs_diff, case_passed = bert_parity_test( + baseline_model_path, + optimized_model_path, + output_dir=None, + batch_size=batch_size, + sequence_length=sequence_length, + use_gpu=False, + test_cases=1, + seed=123, + verbose=False, + rtol=1e-4, + atol=1e-4, + input_ids_name=input_names[0], + segment_ids_name=input_names[2] if inputs_count > 2 else None, + input_mask_name=input_names[1] if inputs_count > 1 else None, + mask_type=2, + dictionary_size=1024, + ) + self.assertTrue( + case_passed, f"bert parity test failed: {batch_size=} {sequence_length=} {max_abs_diff=}" + ) + + def test_bert(self): + model_name = "hf-internal-testing/tiny-random-bert" + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=1) + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=2) + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=3) + + def test_roberta(self): + model_name = "hf-internal-testing/tiny-random-roberta" + # TODO: EmbedLayerNormalization fusion. + self.run_optimizer_on_model(model_name, [0, 5, 0, 0, 5, 1, 10], inputs_count=1) + self.run_optimizer_on_model(model_name, [0, 5, 0, 0, 5, 1, 10], inputs_count=2) + + def test_distillbert(self): + model_name = "hf-internal-testing/tiny-random-distilbert" + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=1) + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=2) + + def test_xlm_roberta(self): + model_name = "hf-internal-testing/tiny-xlm-roberta" + # TODO: EmbedLayerNormalization fusion. 
+ self.run_optimizer_on_model(model_name, [0, 2, 0, 0, 2, 1, 4], inputs_count=1) + self.run_optimizer_on_model(model_name, [0, 2, 0, 0, 2, 1, 4], inputs_count=2) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py index 1e7940e38335f..baaaeaa766db9 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -651,7 +651,6 @@ def parity_check(self): torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) if ort_output is not None: - assert torch.allclose(torch_output, ort_output.to(torch.float32), rtol=THRESHOLD, atol=THRESHOLD) print( "name:", self.__class__.__name__, @@ -661,8 +660,8 @@ def parity_check(self): self.sequence_length, " max_diff:", (torch_output - ort_output).abs().max(), - " parity: OK", ) + torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD) def benchmark_ort(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) @@ -996,6 +995,13 @@ def small_test_cases(): yield batch_size, sequence_length +def phi3_test_cases(): + # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation. + for batch_size in [1, 4, 16]: + for sequence_length in [128]: + yield batch_size, sequence_length + + class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_switch_moe_parity(self, batch_size, sequence_length): @@ -1023,7 +1029,7 @@ def test_mixtral_moe_parity(self, batch_size, sequence_length): class TestPhiMoE(unittest.TestCase): - @parameterized.expand(small_test_cases()) + @parameterized.expand(phi3_test_cases()) def test_phi3_moe_parity(self, batch_size, sequence_length): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) diff --git a/onnxruntime/test/python/transformers/test_parity_t5_mha.py b/onnxruntime/test/python/transformers/test_parity_t5_mha.py index e4f65b07c552e..84708ddcf82a5 100644 --- a/onnxruntime/test/python/transformers/test_parity_t5_mha.py +++ b/onnxruntime/test/python/transformers/test_parity_t5_mha.py @@ -159,6 +159,7 @@ def create_t5_decoder_masked_mha_graph( head_size, num_heads, is_cross_attention, + beam_width=1, ): nodes = [ helper.make_node( @@ -172,6 +173,8 @@ def create_t5_decoder_masked_mha_graph( "past_key" if not is_cross_attention else "", "past_value" if not is_cross_attention else "", "past_sequence_length" if not is_cross_attention else "", + "beam_width" if beam_width > 1 else "", + "cache_indirection" if beam_width > 1 else "", ], [ "output", @@ -233,6 +236,15 @@ def create_t5_decoder_masked_mha_graph( ) ) graph_inputs.append(helper.make_tensor_value_info("past_sequence_length", TensorProto.INT32, [1])) + + if beam_width > 1: + graph_inputs.append(helper.make_tensor_value_info("beam_width", TensorProto.INT32, [1])) + graph_inputs.append( + helper.make_tensor_value_info( + "cache_indirection", TensorProto.INT32, [batch_size, beam_width, past_sequence_length + 1] + ) + ) + graph_outputs.append( helper.make_tensor_value_info( "present_key", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size] @@ -275,7 +287,7 @@ def __init__(self, is_decoder, batch_size, seq_len, kv_sequence_length, num_head class T5Attention(nn.Module): - def __init__(self, config: T5Config, 
is_static_kv, use_decoder_masked_kernel: bool = False): + def __init__(self, config: T5Config, is_static_kv, use_decoder_masked_kernel: bool = False, beam_width=1): super().__init__() self.is_decoder = config.is_decoder self.is_static_kv = is_static_kv @@ -284,6 +296,7 @@ def __init__(self, config: T5Config, is_static_kv, use_decoder_masked_kernel: bo self.key_value_proj_dim = config.head_size self.n_heads = config.num_heads self.inner_dim = self.n_heads * self.key_value_proj_dim + self.beam_width = beam_width # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -312,6 +325,7 @@ def __init__(self, config: T5Config, is_static_kv, use_decoder_masked_kernel: bo self.head_size, self.num_heads, is_static_kv, + beam_width=self.beam_width, ) else: self.onnx_graph = create_t5_mha_graph( @@ -371,7 +385,17 @@ def create_inputs(self): position_bias = torch.normal(mean=5, std=0.1, size=(1, self.num_heads, 1, position_bias_length)).to( torch.float32 ) - return hidden_states, key_value_states, past_key_value, attention_mask, position_bias + + inputs = [hidden_states, key_value_states, past_key_value, attention_mask, position_bias] + if self.beam_width > 1: + # Treat total_sequence_length as max_sequence_length here. + max_sequence_length = self.kv_sequence_length + self.seq_len + # Use random generated values here, which may not be valid in real case. + cache_indirection = torch.randint( + 0, self.beam_width, (self.batch_size, self.beam_width, max_sequence_length) + ).to(torch.int32) + inputs.append(cache_indirection) + return inputs def torch_forward( self, @@ -497,16 +521,17 @@ def ort_forward( past_key_value=None, mask=None, position_bias=None, - use_cache=False, + use_cuda=True, query_length=None, + cache_indirection=None, ): import onnxruntime sess_options = onnxruntime.SessionOptions() - cuda_providers = ["CUDAExecutionProvider"] - if cuda_providers[0] not in onnxruntime.get_available_providers(): - return None - ort_session = onnxruntime.InferenceSession(self.onnx_graph, sess_options, providers=cuda_providers) + execution_providers = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"] + if execution_providers[0] not in onnxruntime.get_available_providers(): + return + ort_session = onnxruntime.InferenceSession(self.onnx_graph, sess_options, providers=execution_providers) batch_size, seq_length = hidden_states.shape[:2] @@ -578,13 +603,17 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): torch_past_value_padded[:, :, : torch_past_value.shape[2], :] = torch_past_value if self.is_static_kv: if self.use_decoder_masked_kernel: - reordered_past_key = self.reorder_key_cache( - torch_past_key.flatten(), - batch_size=batch_size, - num_heads=self.num_heads, - sequence_length=self.kv_sequence_length, - head_size=self.head_size, - max_sequence_length=self.kv_sequence_length, + reordered_past_key = ( + self.reorder_key_cache( + torch_past_key.flatten(), + batch_size=batch_size, + num_heads=self.num_heads, + sequence_length=self.kv_sequence_length, + head_size=self.head_size, + max_sequence_length=self.kv_sequence_length, + ) + if use_cuda + else torch_past_key ) ort_inputs["key"] = reordered_past_key.reshape(torch_past_key.shape) ort_inputs["value"] = torch_past_value @@ -595,13 +624,17 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): ort_inputs["key"] = np.ascontiguousarray(key_states.detach().numpy()) ort_inputs["value"] = 
np.ascontiguousarray(value_states.detach().numpy()) if self.use_decoder_masked_kernel: - reordered_past_key = self.reorder_key_cache( - torch_past_key_padded.flatten(), - batch_size=batch_size, - num_heads=self.num_heads, - sequence_length=self.kv_sequence_length, - head_size=self.head_size, - max_sequence_length=max_seq_len, + reordered_past_key = ( + self.reorder_key_cache( + torch_past_key_padded.flatten(), + batch_size=batch_size, + num_heads=self.num_heads, + sequence_length=self.kv_sequence_length, + head_size=self.head_size, + max_sequence_length=max_seq_len, + ) + if use_cuda + else torch_past_key_padded ) ort_inputs["past_key"] = reordered_past_key.reshape(torch_past_value_padded.shape) ort_inputs["past_value"] = torch_past_value_padded @@ -617,6 +650,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if torch_position_bias is not None: ort_inputs["attention_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) + if self.beam_width > 1: + ort_inputs["beam_width"] = np.ascontiguousarray(np.array([self.beam_width], dtype=np.int32)) + ort_inputs["cache_indirection"] = np.ascontiguousarray(cache_indirection.detach().numpy()) + ort_output = ort_session.run(None, ort_inputs) output = None @@ -628,7 +665,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return output -def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False): +def compare_t5_cross_attention_decoder( + batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False, use_cuda=True +): config = T5Config( is_decoder=True, batch_size=batch_size, @@ -646,7 +685,7 @@ def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size hidden_states, key_value_states, past_key_value, attention_mask, position_bias=None, use_cache=False ) ort_output = T5CrossAttention.ort_forward( - hidden_states, key_value_states, past_key_value, attention_mask, position_bias=None, use_cache=False + hidden_states, key_value_states, past_key_value, attention_mask, position_bias=None, use_cuda=use_cuda ) if ort_output is not None: @@ -669,9 +708,7 @@ def compare_t5_cross_attention_decoder_init(batch_size, seq_len, num_heads, head torch_output = T5CrossAttention.torch_forward( hidden_states, key_value_states, None, attention_mask, position_bias=None, use_cache=True ) - ort_output = T5CrossAttention.ort_forward( - hidden_states, key_value_states, None, attention_mask, position_bias=None, use_cache=True - ) + ort_output = T5CrossAttention.ort_forward(hidden_states, key_value_states, None, attention_mask, position_bias=None) if ort_output is not None: assert torch.allclose(torch_output[0], ort_output[0], atol=1e-4) @@ -695,9 +732,7 @@ def compare_t5_self_attention_decoder_init(batch_size, seq_len, num_heads, head_ torch_output = T5CrossAttention.torch_forward( hidden_states, None, None, mask=None, position_bias=position_bias, use_cache=True ) - ort_output = T5CrossAttention.ort_forward( - hidden_states, None, None, mask=None, position_bias=position_bias, use_cache=True - ) + ort_output = T5CrossAttention.ort_forward(hidden_states, None, None, mask=None, position_bias=position_bias) if ort_output is not None: assert torch.allclose(torch_output[0], ort_output[0], atol=1e-4) @@ -705,7 +740,9 @@ def compare_t5_self_attention_decoder_init(batch_size, seq_len, num_heads, head_ assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4) -def compare_t5_self_attention_decoder(batch_size, 
seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False): +def compare_t5_self_attention_decoder( + batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False, use_cuda=True, beam_width=1 +): config = T5Config( is_decoder=True, batch_size=batch_size, @@ -716,21 +753,45 @@ def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, use_past=True, ) - T5CrossAttention = T5Attention(config, is_static_kv=False, use_decoder_masked_kernel=use_dmmha) # noqa: N806 - - hidden_states, _, past_key_value, _, position_bias = T5CrossAttention.create_inputs() - torch_output = T5CrossAttention.torch_forward( - hidden_states, None, past_key_value, mask=None, position_bias=position_bias, use_cache=True + T5CrossAttention = T5Attention( # noqa: N806 + config, is_static_kv=False, use_decoder_masked_kernel=use_dmmha, beam_width=beam_width ) + + hidden_states, _, past_key_value, _, position_bias, *maybe_cache_indirection = T5CrossAttention.create_inputs() + cache_indirection = maybe_cache_indirection[0] if beam_width > 1 else None + if beam_width > 1: + # When beam_width > 1, use ORT CUDA result as reference + ref_output = T5CrossAttention.ort_forward( + hidden_states, + None, + past_key_value, + mask=None, + position_bias=position_bias, + cache_indirection=cache_indirection, + use_cuda=True, + ) + if ref_output is None: + # Return directly if CUDA EP is not available + return + else: + ref_output = T5CrossAttention.torch_forward( + hidden_states, None, past_key_value, mask=None, position_bias=position_bias, use_cache=True + ) ort_output = T5CrossAttention.ort_forward( - hidden_states, None, past_key_value, mask=None, position_bias=position_bias, use_cache=True + hidden_states, + None, + past_key_value, + mask=None, + position_bias=position_bias, + cache_indirection=cache_indirection, + use_cuda=use_cuda, ) if ort_output is not None: - assert torch.allclose(torch_output[0], ort_output[0], atol=1e-4) + assert torch.allclose(ref_output[0], ort_output[0], atol=1e-4) if not use_dmmha: - assert torch.allclose(torch_output[1][0], ort_output[1][0], atol=1e-4) - assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4) + assert torch.allclose(ref_output[1][0], ort_output[1][0], atol=1e-4) + assert torch.allclose(ref_output[1][1], ort_output[1][1], atol=1e-4) class TestT5MHAParity(unittest.TestCase): @@ -761,23 +822,53 @@ def test_t5_self_attention_decoder(self): self.batch_size, self.seq_len, self.num_heads, self.head_size, self.kv_sequence_length ) - def test_t5_cross_attention_decoder_masked_mha(self): + def test_t5_cross_attention_decoder_masked_mha(self, use_cuda=True): batch_size = 2 seq_len = 1 num_heads = 2 head_size = 32 kv_sequence_length = 2 compare_t5_cross_attention_decoder( - batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True + batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True, use_cuda=use_cuda ) - def test_t5_self_attention_decoder_masked_mha(self): + def test_t5_self_attention_decoder_masked_mha(self, use_cuda=True): batch_size = 2 seq_len = 1 num_heads = 2 head_size = 32 kv_sequence_length = 2 - compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True) + compare_t5_self_attention_decoder( + batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True, use_cuda=use_cuda + ) + + def test_t5_cross_attention_decoder_masked_mha_cpu(self): + return self.test_t5_cross_attention_decoder_masked_mha(use_cuda=False) + + def 
test_t5_self_attention_decoder_masked_mha_cpu(self): + return self.test_t5_self_attention_decoder_masked_mha(use_cuda=False) + + def test_t5_self_attention_decoder_masked_mha_with_beams(self): + """ + Test DecoderMaskedMultiHeadAttention self-attention case with beam_width > 1. + Compare the results on CUDA and CPU EPs. + """ + batch_size = 4 + seq_len = 1 + num_heads = 2 + head_size = 32 + kv_sequence_length = 2 + beam_width = 2 + compare_t5_self_attention_decoder( + batch_size, + seq_len, + num_heads, + head_size, + kv_sequence_length, + use_dmmha=True, + use_cuda=False, + beam_width=beam_width, + ) if __name__ == "__main__": diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index 102846e08ac5f..24c343c7b9541 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -48,6 +48,9 @@ namespace qnnctxgen { "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -143,7 +146,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str); } - } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing") { + } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing" || + key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -154,7 +158,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, } } else { ORT_THROW(R"(Wrong key type entered. 
Choose from options: ['backend_path', 'vtcm_mb', 'htp_performance_mode', - 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing'])"); + 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing', + 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])"); } test_config.run_config.qnn_options[key] = value; diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc index d568d5e78688a..3be0bd253c8a4 100644 --- a/onnxruntime/test/qnn_ctx_gen/main.cc +++ b/onnxruntime/test/qnn_ctx_gen/main.cc @@ -33,8 +33,11 @@ static void CheckStatus(const Status& status) { // from the last context cache Onnx model, find the EPContext node with main_context=1, // and get the QNN context binary file name, this context binary contains all graphs from all Onnx models +// get the max spill fill buffer size static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, - std::string& last_ctx_bin_file) { + std::string& last_ctx_bin_file, + int64_t& max_size) { + max_size = 0; std::shared_ptr ctx_model; CheckStatus(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); @@ -43,6 +46,7 @@ static void GetLastContextBinaryFileName(const std::basic_string last if (node.OpType() == "EPContext") { NodeAttrHelper node_helper(node); int64_t is_main_context = node_helper.Get("main_context", static_cast(0)); + max_size = node_helper.Get("max_size", static_cast(0)); if (1 == is_main_context) { last_ctx_bin_file = node_helper.Get("ep_cache_context", ""); return; @@ -55,7 +59,8 @@ static void GetLastContextBinaryFileName(const std::basic_string last // the last QNN context binary file // Remove not used QNN context binary file, only keep the last one which contains all graphs static void UpdateEpContextModel(const std::vector>& ep_ctx_files, - const std::string& last_qnn_ctx_binary_file_name) { + const std::string& last_qnn_ctx_binary_file_name, + int64_t max_size) { for (auto ep_ctx_file : ep_ctx_files) { std::shared_ptr ctx_model; auto path_str = ToPathString(ep_ctx_file); @@ -75,6 +80,8 @@ static void UpdateEpContextModel(const std::vector> std::remove(file_path.string().c_str()); node.ClearAttribute("ep_cache_context"); node.AddAttribute("ep_cache_context", last_qnn_ctx_binary_file_name); + node.ClearAttribute("max_size"); + node.AddAttribute("max_size", max_size); } } } @@ -181,7 +188,8 @@ int real_main(int argc, char* argv[]) { // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name); + int64_t max_size = 0; + GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; if (last_qnn_ctx_binary_file_name.empty()) { throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); @@ -191,7 +199,7 @@ int real_main(int argc, char* argv[]) { // Update generated context cache Onnx model to make the main EPContext node point to // the last QNN context binary file // Remove not used QNN context binary file, only keep the last one which contains all graphs - UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name); + UpdateEpContextModel(ep_ctx_files, 
last_qnn_ctx_binary_file_name, max_size); } ORT_CATCH(const Ort::Exception& e) { fprintf(stderr, "Failed to generate context cache file: %s \n", e.what()); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 5816539f9194f..e8c8c8db8d08f 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -335,6 +335,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod #endif } else if (provider_type == 3) { #ifdef USE_ROCM + std::cout << "Running simple inference with rocm provider" << std::endl; OrtROCMProviderOptions rocm_options; session_options.AppendExecutionProvider_ROCM(rocm_options); #else @@ -384,7 +385,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx"); -#if defined(USE_CUDA) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx"); #endif static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx"); @@ -2341,7 +2342,7 @@ TEST(CApiTest, basic_cuda_graph) { #endif } -#if defined(USE_CUDA) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) struct CudaGraphInputOutputData_0 { const std::array x_shape = {3, 2}; std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -2385,6 +2386,12 @@ static void RunWithCudaGraphAnnotation(T& cg_data, Ort::MemoryAllocation& input_data, Ort::MemoryAllocation& output_data, const char* cuda_graph_annotation) { +// a local hipify of select cuda symbols to avoid code duplication +#ifdef USE_ROCM +#define cudaMemcpy hipMemcpy +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#endif #ifdef USE_DML Ort::SessionOptions session_options; Ort::Allocator allocator(session, info_mem); @@ -2488,6 +2495,11 @@ static void RunWithCudaGraphAnnotation(T& cg_data, // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); +#ifdef USE_ROCM +#undef cudaMemcpy +#undef cudaMemcpyHostToDevice +#undef cudaMemcpyDeviceToHost +#endif } TEST(CApiTest, basic_cuda_graph_with_annotation) { @@ -2502,7 +2514,7 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { ort_dml_api->SessionOptionsAppendExecutionProvider_DML1(session_options, dml_objects.dml_device.Get(), dml_objects.command_queue.Get()); Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); -#else +#elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -2516,6 +2528,20 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { static_cast(session_options), rel_cuda_options.get()) == nullptr); Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#elif defined(USE_ROCM) + // Enable hip graph in rocm provider option. 
+ OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); + Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); #endif Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options); diff --git a/onnxruntime/test/shared_lib/test_nontensor_types.cc b/onnxruntime/test/shared_lib/test_nontensor_types.cc index 8171a6eecc91d..ba16bd6c9888f 100644 --- a/onnxruntime/test/shared_lib/test_nontensor_types.cc +++ b/onnxruntime/test/shared_lib/test_nontensor_types.cc @@ -987,6 +987,32 @@ TEST(CApiTest, SparseTensorFillSparseTensorFormatAPI) { } } +TEST(CApi, TestResize) { + std::vector values; + values.resize(10); + + std::vector sts; + sts.resize(5); + + std::vector domains; + domains.resize(5); + + std::vector type_and_shape; + type_and_shape.resize(5); + + std::vector seq_type_info; + seq_type_info.resize(5); + + std::vector map_type_info; + map_type_info.resize(5); + + std::vector type_info; + type_info.resize(5); + + std::vector op_attr; + op_attr.resize(5); +} + TEST(CApiTest, SparseTensorFillSparseFormatStringsAPI) { auto allocator = Ort::AllocatorWithDefaultOptions(); Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc index 069246b4201e7..807182ee28946 100644 --- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc @@ -28,7 +28,7 @@ void KernelOne(const Ort::Custom::RocmContext& rocm_ctx, auto input_shape = X.Shape(); CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream"); CUSTOM_ENFORCE(rocm_ctx.miopen_handle, "failed to fetch miopen handle"); - CUSTOM_ENFORCE(rocm_ctx.rblas_handle, "failed to fetch rocblas handle"); + CUSTOM_ENFORCE(rocm_ctx.blas_handle, "failed to fetch rocblas handle"); auto z_raw = Z.Allocate(input_shape); rocm_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), rocm_ctx.hip_stream); } @@ -40,4 +40,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) { } // namespace Rocm -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/testdata/dummy_t5.onnx b/onnxruntime/test/testdata/dummy_t5.onnx new file mode 100644 index 0000000000000..3a3bbf4767523 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5.onnx differ diff --git a/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx new file mode 100644 index 0000000000000..4b36cc9b6eca0 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx differ diff --git a/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx new file mode 100644 index 0000000000000..5a5c302914890 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx differ diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc 
b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 6681f7a93971e..3d07bfcce101c 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -306,6 +306,11 @@ "^test_qlinearmatmul_3D_int8_float32_cuda", "^test_qlinearmatmul_3D_uint8_float16_cuda", "^test_qlinearmatmul_3D_uint8_float32_cuda", + // Tests that failed on CUDA 12.2. + "^test_Conv3d_dilated_cuda", + "^test_Conv3d_dilated_strided_cuda", + "^test_Conv3d_stride_cuda", + "^test_Conv3d_stride_padding_cuda", // Size(21) from ONNX 1.16.0 is not implemented in cuda. "^test_size_cuda", "^test_size_example_cuda", @@ -712,6 +717,30 @@ "^test_nonmaxsuppression_flipped_coordinates_cpu", "^test_nonmaxsuppression_center_point_box_format_cpu" ], + "current_failing_tests_WEBGPU": [ + "^test_layer_normalization_2d_axis0_cpu", + "^test_layer_normalization_2d_axis1_cpu", + "^test_layer_normalization_2d_axis_negative_1_cpu", + "^test_layer_normalization_2d_axis_negative_2_cpu", + "^test_layer_normalization_3d_axis0_epsilon_cpu", + "^test_layer_normalization_3d_axis1_epsilon_cpu", + "^test_layer_normalization_3d_axis2_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_1_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_2_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_3_epsilon_cpu", + "^test_layer_normalization_4d_axis0_cpu", + "^test_layer_normalization_4d_axis1_cpu", + "^test_layer_normalization_4d_axis2_cpu", + "^test_layer_normalization_4d_axis3_cpu", + "^test_layer_normalization_4d_axis_negative_1_cpu", + "^test_layer_normalization_4d_axis_negative_2_cpu", + "^test_layer_normalization_4d_axis_negative_3_cpu", + "^test_layer_normalization_4d_axis_negative_4_cpu", + "^test_layer_normalization_default_axis_cpu", + "^test_gelu_tanh_1_expanded_cpu", + "^test_gelu_tanh_2_expanded_cpu", + "^test_dynamicquantizelinear_expanded_cpu" + ], "current_failing_tests_pure_DML": [ "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu", "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu", diff --git a/onnxruntime/test/testdata/relu_with_optional.onnx b/onnxruntime/test/testdata/relu_with_optional.onnx new file mode 100644 index 0000000000000..b52c6927527bd Binary files /dev/null and b/onnxruntime/test/testdata/relu_with_optional.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx new file mode 100644 index 0000000000000..8ca8282572db8 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.onnx new file mode 100644 index 0000000000000..2521a89b7bb56 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.onnx @@ -0,0 +1,41 @@ + : +R +inputA a_quantizeda_scalea_zpDynamicQuantizeLinear"DynamicQuantizeLinear +Y + a_quantized +inputB +a_zp +inputBZPmatmulinteger_output MatMulInteger" MatMulInteger +- +a_scale + inputBScalemul_1 mul_right"Mul +: +matmulinteger_output cast_outputcast"Cast* +to +- +mul_1 + cast_outputoutput +mul_bottom"Mul+matmul_integer_to_float_large_tensor_fusionZ" +inputA + + + + + + +Z 
+inputB + + + + + +Z +inputBZP + + +Z + inputBScale + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py new file mode 100644 index 0000000000000..543517cc015ef --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py @@ -0,0 +1,49 @@ +from enum import Enum # noqa: F401 + +import onnx +from onnx import TensorProto, helper + + +def GenerateModel(model_name): # noqa: N802 + inputs = [] + outputs = [] + initializers = [] + nodes = [] + + inputs.append(helper.make_tensor_value_info("inputA", TensorProto.FLOAT, [16, 32, 1280, 1280])) + inputs.append(helper.make_tensor_value_info("inputB", TensorProto.INT8, [1280, 1280])) + inputs.append(helper.make_tensor_value_info("inputBZP", TensorProto.INT8, [1])) + inputs.append(helper.make_tensor_value_info("inputBScale", TensorProto.FLOAT, [1])) + + nodes = [ # construct graph + helper.make_node( + "DynamicQuantizeLinear", + ["inputA"], + ["a_quantized", "a_scale", "a_zp"], + "DynamicQuantizeLinear", + ), + helper.make_node( + "MatMulInteger", + ["a_quantized", "inputB", "a_zp", "inputBZP"], + ["matmulinteger_output"], + "MatMulInteger", + ), + helper.make_node("Mul", ["a_scale", "inputBScale"], ["mul_1"], "mul_right"), + helper.make_node("Cast", ["matmulinteger_output"], ["cast_output"], "cast", to=1), + helper.make_node("Mul", ["mul_1", "cast_output"], ["output"], "mul_bottom"), + ] + + graph = helper.make_graph( + nodes, + "matmul_integer_to_float_large_tensor_fusion", # name + inputs, + outputs, + initializers, + ) + + model = helper.make_model(graph) + onnx.save(model, model_name) + + +if __name__ == "__main__": + GenerateModel("matmul_integer_to_float_large_tensor.onnx") diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 1d89272680e47..b558a7f00f7bc 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -27,8 +27,8 @@ std::unique_ptr ort_env; -// ortenv_setup is used by /onnxruntime/test/xctest/xcgtest.mm so can't be file local -void ortenv_setup() { +// ortenv_setup() and ortenv_teardown() are used by onnxruntime/test/xctest/xcgtest.mm so can't be file local +extern "C" void ortenv_setup() { OrtThreadingOptions tpo; // allow verbose logging to be enabled by setting this environment variable to a numeric log level @@ -46,6 +46,10 @@ void ortenv_setup() { ort_env.reset(new Ort::Env(&tpo, log_level, "Default")); } +extern "C" void ortenv_teardown() { + ort_env.reset(); +} + #ifdef USE_TENSORRT #if defined(_MSC_VER) @@ -101,7 +105,7 @@ int TEST_MAIN(int argc, char** argv) { } // TODO: Fix the C API issue - ort_env.reset(); // If we don't do this, it will crash + ortenv_teardown(); // If we don't do this, it will crash #ifndef USE_ONNXRUNTIME_DLL // make memory leak checker happy diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index d07e01c1a4e27..c1564997c42b8 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -99,11 +99,13 @@ std::unique_ptr MIGraphXExecutionProviderWithOptions(const O return nullptr; } -std::unique_ptr OpenVINOExecutionProviderWithOptions(const OrtOpenVINOProviderOptions* params) { +std::unique_ptr OpenVINOExecutionProviderWithOptions(const ProviderOptions* params, + const SessionOptions* session_options) { 
#ifdef USE_OPENVINO - return OpenVINOProviderFactoryCreator::Create(params)->CreateProvider(); + return OpenVINOProviderFactoryCreator::Create(params, session_options)->CreateProvider(); #else ORT_UNUSED_PARAMETER(params); + ORT_UNUSED_PARAMETER(session_options); return nullptr; #endif } @@ -174,14 +176,6 @@ std::unique_ptr DnnlExecutionProviderWithOptions(const OrtDn return nullptr; } -// std::unique_ptr DefaultTvmExecutionProvider() { -// #ifdef USE_TVM -// return TVMProviderFactoryCreator::Create("")->CreateProvider(); -// #else -// return nullptr; -// #endif -// } - std::unique_ptr DefaultNnapiExecutionProvider() { // The NNAPI EP uses a stub implementation on non-Android platforms so cannot be used to execute a model. // Manually append an NNAPI EP instance to the session to unit test the GetCapability and Compile implementation. @@ -245,14 +239,14 @@ std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlpr // The test will create a model but execution of it will obviously fail. #if defined(USE_COREML) && defined(__APPLE__) // We want to run UT on CPU only to get output value without losing precision - uint32_t coreml_flags = 0; - coreml_flags |= COREML_FLAG_USE_CPU_ONLY; + auto option = ProviderOptions(); + option[kCoremlProviderOption_MLComputeUnits] = "CPUOnly"; if (use_mlprogram) { - coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; + option[kCoremlProviderOption_ModelFormat] = "MLProgram"; } - return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); + return CoreMLProviderFactoryCreator::Create(option)->CreateProvider(); #else ORT_UNUSED_PARAMETER(use_mlprogram); return nullptr; @@ -305,6 +299,10 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { std::unique_ptr DefaultWebGpuExecutionProvider() { #ifdef USE_WEBGPU ConfigOptions config_options{}; + // Disable storage buffer cache + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + webgpu::options::kBufferCacheMode_Disabled) + .IsOK()); return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); #else return nullptr; diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 1fd9894e09d4e..9b44150d972db 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -20,7 +20,6 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); -// std::shared_ptr CreateExecutionProviderFactory_Tvm(const char*); std::shared_ptr CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options); @@ -49,7 +48,7 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params); std::unique_ptr DefaultMIGraphXExecutionProvider(); std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params); -std::unique_ptr OpenVINOExecutionProviderWithOptions(const OrtOpenVINOProviderOptions* params); +std::unique_ptr OpenVINOExecutionProviderWithOptions(const ProviderOptions* params, const SessionOptions* session_options = nullptr); std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr 
DefaultNnapiExecutionProvider(); std::unique_ptr DefaultVSINPUExecutionProvider(); diff --git a/onnxruntime/test/util/include/providers.h b/onnxruntime/test/util/include/providers.h index a73b237ae10df..01be1a444646b 100644 --- a/onnxruntime/test/util/include/providers.h +++ b/onnxruntime/test/util/include/providers.h @@ -7,9 +7,6 @@ #ifdef USE_DNNL #include "core/providers/dnnl/dnnl_provider_factory.h" #endif -#ifdef USE_TVM -#include "core/providers/tvm/tvm_provider_factory.h" -#endif #ifdef USE_OPENVINO #include "core/providers/openvino/openvino_provider_factory.h" #endif diff --git a/onnxruntime/test/wasm/package-lock.json b/onnxruntime/test/wasm/package-lock.json index 522e96fc3188a..3bd5d173dbe79 100644 --- a/onnxruntime/test/wasm/package-lock.json +++ b/onnxruntime/test/wasm/package-lock.json @@ -27,9 +27,9 @@ } }, "node_modules/@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "node_modules/@types/cookie": { @@ -39,19 +39,22 @@ "dev": true }, "node_modules/@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "dependencies": { "@types/node": "*" } }, "node_modules/@types/node": { - "version": "18.13.0", - "resolved": "https://registry.npmjs.org/@types/node/-/node-18.13.0.tgz", - "integrity": "sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==", - "dev": true + "version": "22.10.1", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.10.1.tgz", + "integrity": "sha512-qKgsUwfHZV2WCWLAnVP1JqnpE6Im6h3Y0+fYgMTasNQ7V++CBX5OT1as0g0f+OyubbFqhf6XVNIsmN4IIhEgGQ==", + "dev": true, + "dependencies": { + "undici-types": "~6.20.0" + } }, "node_modules/accepts": { "version": "1.3.8", @@ -162,12 +165,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -288,9 +291,9 @@ } }, "node_modules/cookie": { - "version": "0.4.2", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.2.tgz", - "integrity": "sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==", + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": 
"sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", "dev": true, "engines": { "node": ">= 0.6" @@ -409,9 +412,9 @@ } }, "node_modules/engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.6.2.tgz", + "integrity": "sha512-gmNvsYi9C8iErnZdVcJnvCpSKbWTt1E8+JZo8b+daLninywUWi5NQ5STSHZ9rFjFO7imNcvb8Pc5pe/wMR5xEw==", "dev": true, "dependencies": { "@types/cookie": "^0.4.1", @@ -419,32 +422,32 @@ "@types/node": ">=10.0.0", "accepts": "~1.3.4", "base64id": "2.0.0", - "cookie": "~0.4.1", + "cookie": "~0.7.2", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - "ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, "node_modules/engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true, "engines": { "node": ">=10.0.0" } }, "node_modules/engine.io/node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "dependencies": { - "ms": "2.1.2" + "ms": "^2.1.3" }, "engines": { "node": ">=6.0" @@ -456,9 +459,9 @@ } }, "node_modules/engine.io/node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true }, "node_modules/ent": { @@ -516,9 +519,9 @@ "dev": true }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "dependencies": { "to-regex-range": "^5.0.1" @@ -1304,35 +1307,60 @@ } }, "node_modules/socket.io": { - "version": "4.6.0", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.0.tgz", - "integrity": "sha512-b65bp6INPk/BMMrIgVvX12x3Q+NqlGqSlTuvKQWt0BUJ3Hyy3JangBl7fEoWZTXbOKlCqNPbQ6MbWgok/km28w==", + "version": "4.8.1", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.8.1.tgz", + "integrity": 
"sha512-oZ7iUCxph8WYRHHcjBEc9unw3adt5CmSNlppj/5Q4k2RIrhl8Z5yY2Xr4j9zj0+wzVZ0bxmYoGSzKJnRl6A4yg==", "dev": true, "dependencies": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.0", + "engine.io": "~6.6.0", "socket.io-adapter": "~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, "node_modules/socket.io-adapter": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", + "dev": true, + "dependencies": { + "debug": "~4.3.4", + "ws": "~8.17.1" + } + }, + "node_modules/socket.io-adapter/node_modules/debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "dependencies": { - "ws": "~8.11.0" + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } } }, + "node_modules/socket.io-adapter/node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true + }, "node_modules/socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "dependencies": { "@socket.io/component-emitter": "~3.1.0", @@ -1343,12 +1371,12 @@ } }, "node_modules/socket.io-parser/node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "dependencies": { - "ms": "2.1.2" + "ms": "^2.1.3" }, "engines": { "node": ">=6.0" @@ -1360,9 +1388,9 @@ } }, "node_modules/socket.io-parser/node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true }, "node_modules/socket.io/node_modules/debug": { @@ -1534,6 +1562,12 @@ "node": "*" } }, + "node_modules/undici-types": { + "version": "6.20.0", + "resolved": 
"https://registry.npmjs.org/undici-types/-/undici-types-6.20.0.tgz", + "integrity": "sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==", + "dev": true + }, "node_modules/universalify": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.1.2.tgz", @@ -1615,16 +1649,16 @@ "dev": true }, "node_modules/ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "engines": { "node": ">=10.0.0" }, "peerDependencies": { "bufferutil": "^4.0.1", - "utf-8-validate": "^5.0.2" + "utf-8-validate": ">=5.0.2" }, "peerDependenciesMeta": { "bufferutil": { @@ -1686,9 +1720,9 @@ "dev": true }, "@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "@types/cookie": { @@ -1698,19 +1732,22 @@ "dev": true }, "@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "requires": { "@types/node": "*" } }, "@types/node": { - "version": "18.13.0", - "resolved": "https://registry.npmjs.org/@types/node/-/node-18.13.0.tgz", - "integrity": "sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==", - "dev": true + "version": "22.10.1", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.10.1.tgz", + "integrity": "sha512-qKgsUwfHZV2WCWLAnVP1JqnpE6Im6h3Y0+fYgMTasNQ7V++CBX5OT1as0g0f+OyubbFqhf6XVNIsmN4IIhEgGQ==", + "dev": true, + "requires": { + "undici-types": "~6.20.0" + } }, "accepts": { "version": "1.3.8", @@ -1796,12 +1833,12 @@ } }, "braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "requires": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" } }, "bytes": { @@ -1890,9 +1927,9 @@ "dev": true }, "cookie": { - "version": "0.4.2", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.2.tgz", - "integrity": "sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==", + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + 
"integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", "dev": true }, "cors": { @@ -1986,9 +2023,9 @@ "dev": true }, "engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.6.2.tgz", + "integrity": "sha512-gmNvsYi9C8iErnZdVcJnvCpSKbWTt1E8+JZo8b+daLninywUWi5NQ5STSHZ9rFjFO7imNcvb8Pc5pe/wMR5xEw==", "dev": true, "requires": { "@types/cookie": "^0.4.1", @@ -1996,34 +2033,34 @@ "@types/node": ">=10.0.0", "accepts": "~1.3.4", "base64id": "2.0.0", - "cookie": "~0.4.1", + "cookie": "~0.7.2", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - "ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" }, "dependencies": { "debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "requires": { - "ms": "2.1.2" + "ms": "^2.1.3" } }, "ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true } } }, "engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true }, "ent": { @@ -2072,9 +2109,9 @@ "dev": true }, "fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "requires": { "to-regex-range": "^5.0.1" @@ -2651,17 +2688,18 @@ } }, "socket.io": { - "version": "4.6.0", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.0.tgz", - "integrity": "sha512-b65bp6INPk/BMMrIgVvX12x3Q+NqlGqSlTuvKQWt0BUJ3Hyy3JangBl7fEoWZTXbOKlCqNPbQ6MbWgok/km28w==", + "version": "4.8.1", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.8.1.tgz", + "integrity": "sha512-oZ7iUCxph8WYRHHcjBEc9unw3adt5CmSNlppj/5Q4k2RIrhl8Z5yY2Xr4j9zj0+wzVZ0bxmYoGSzKJnRl6A4yg==", "dev": true, "requires": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.0", + "engine.io": "~6.6.0", "socket.io-adapter": 
"~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" }, "dependencies": { "debug": { @@ -2682,18 +2720,36 @@ } }, "socket.io-adapter": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", "dev": true, "requires": { - "ws": "~8.11.0" + "debug": "~4.3.4", + "ws": "~8.17.1" + }, + "dependencies": { + "debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", + "dev": true, + "requires": { + "ms": "^2.1.3" + } + }, + "ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true + } } }, "socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "requires": { "@socket.io/component-emitter": "~3.1.0", @@ -2701,18 +2757,18 @@ }, "dependencies": { "debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "requires": { - "ms": "2.1.2" + "ms": "^2.1.3" } }, "ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true } } @@ -2817,6 +2873,12 @@ "integrity": "sha512-s8ax/CeZdK9R/56Sui0WM6y9OFREJarMRHqLB2EwkovemBxNQ+Bqu8GAsUnVcXKgphb++ghr/B2BZx4mahujPw==", "dev": true }, + "undici-types": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.20.0.tgz", + "integrity": "sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==", + "dev": true + }, "universalify": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.1.2.tgz", @@ -2874,9 +2936,9 @@ "dev": true }, "ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": "8.17.1", + "resolved": 
"https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "requires": {} }, diff --git a/onnxruntime/test/webgpu/external_dawn/main.cc b/onnxruntime/test/webgpu/external_dawn/main.cc new file mode 100644 index 0000000000000..ed8d2eab94ce9 --- /dev/null +++ b/onnxruntime/test/webgpu/external_dawn/main.cc @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. + +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include + +#include "dawn/native/DawnNative.h" + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + bool no_proc_table = argc > 0 && +#ifdef _WIN32 + wcscmp(L"--no_proc_table", argv[argc - 1]) == 0; +#else + strcmp("--no_proc_table", argv[argc - 1]) == 0; +#endif + + int retval = 0; + Ort::Env env{nullptr}; + try { + env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "Default"}; + + // model is https://github.com/onnx/onnx/blob/v1.15.0/onnx/backend/test/data/node/test_abs/model.onnx + constexpr uint8_t MODEL_DATA[] = {8, 7, 18, 12, 98, 97, 99, 107, 101, 110, + 100, 45, 116, 101, 115, 116, 58, 73, 10, 11, + 10, 1, 120, 18, 1, 121, 34, 3, 65, 98, + 115, 18, 8, 116, 101, 115, 116, 95, 97, 98, + 115, 90, 23, 10, 1, 120, 18, 18, 10, 16, + 8, 1, 18, 12, 10, 2, 8, 3, 10, 2, + 8, 4, 10, 2, 8, 5, 98, 23, 10, 1, + 121, 18, 18, 10, 16, 8, 1, 18, 12, 10, + 2, 8, 3, 10, 2, 8, 4, 10, 2, 8, + 5, 66, 4, 10, 0, 16, 13}; + + Ort::SessionOptions session_options; + session_options.DisableMemPattern(); + std::unordered_map provider_options; + if (!no_proc_table) { + provider_options["dawnProcTable"] = std::to_string(reinterpret_cast(&dawn::native::GetProcs())); + } + session_options.AppendExecutionProvider("WebGPU", provider_options); + Ort::Session session{env, MODEL_DATA, sizeof(MODEL_DATA), session_options}; + + if (no_proc_table) { + std::cerr << "DawnProcTable is not passing to ONNX Runtime, but no exception is thrown." << std::endl; + retval = -1; + } else { + // successfully initialized + std::cout << "Successfully initialized WebGPU EP." << std::endl; + retval = 0; + } + } catch (const std::exception& ex) { + std::cerr << ex.what() << std::endl; + + if (no_proc_table) { + std::cout << "DawnProcTable is not passing to ONNX Runtime, so an exception is thrown as expected." << std::endl; + retval = 0; + } else { + std::cerr << "Unexpected exception." 
<< std::endl; + retval = -1; + } + } + + ::google::protobuf::ShutdownProtobufLibrary(); + return retval; +} diff --git a/onnxruntime/test/xctest/xcgtest.mm b/onnxruntime/test/xctest/xcgtest.mm index c02f18d906cbe..785c9cd937022 100644 --- a/onnxruntime/test/xctest/xcgtest.mm +++ b/onnxruntime/test/xctest/xcgtest.mm @@ -34,7 +34,8 @@ using testing::TestPartResult; using testing::UnitTest; -void ortenv_setup(); +extern "C" void ortenv_setup(); +extern "C" void ortenv_teardown(); static NSString* const GoogleTestDisabledPrefix = @"DISABLED_"; @@ -63,24 +64,51 @@ public: XCTestListener(XCTestCase* testCase) : _testCase(testCase) {} - void OnTestPartResult(const TestPartResult& test_part_result) { + void OnTestPartResult(const TestPartResult& test_part_result) override { if (test_part_result.passed() || test_part_result.skipped()) return; int lineNumber = test_part_result.line_number(); const char* fileName = test_part_result.file_name(); NSString* path = fileName ? [@(fileName) stringByStandardizingPath] : nil; + NSString* summary = @(test_part_result.summary()); NSString* description = @(test_part_result.message()); - [_testCase recordFailureWithDescription:description - inFile:path - atLine:(lineNumber >= 0 ? (NSUInteger)lineNumber : 0) - expected:YES]; + + XCTSourceCodeLocation* sourceCodeLocation = + [[XCTSourceCodeLocation alloc] initWithFilePath:path + lineNumber:lineNumber]; + + XCTSourceCodeContext* sourceCodeContext = + [[XCTSourceCodeContext alloc] initWithLocation:sourceCodeLocation]; + + XCTIssue* issue = [[XCTIssue alloc] initWithType:XCTIssueTypeAssertionFailure + compactDescription:summary + detailedDescription:description + sourceCodeContext:sourceCodeContext + associatedError:nil + attachments:@[]]; + + [_testCase recordIssue:issue]; } private: XCTestCase* _testCase; }; +/** + * A Google Test listener that manages the ORT env setup and teardown. + */ +class OrtEnvManagementListener : public testing::EmptyTestEventListener { + public: + void OnTestProgramStart(const UnitTest& unit_test) override { + ortenv_setup(); + } + + void OnTestProgramEnd(const UnitTest& unit_test) override { + ortenv_teardown(); + } +}; + /** * Registers an XCTestCase subclass for each Google Test case. 
* @@ -179,7 +207,6 @@ + (void)load { object:bundle queue:nil usingBlock:^(NSNotification* notification) { - ortenv_setup(); [self registerTestClasses]; }]; } @@ -201,6 +228,8 @@ + (void)registerTestClasses { delete listeners.Release(listeners.default_result_printer()); free(argv); + listeners.Append(new OrtEnvManagementListener()); + BOOL runDisabledTests = GTEST_FLAG_GET(also_run_disabled_tests); NSMutableDictionary* testFilterMap = [NSMutableDictionary dictionary]; NSCharacterSet* decimalDigitCharacterSet = [NSCharacterSet decimalDigitCharacterSet]; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 5173125cb8634..7adfc6a2b2ccb 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -28,7 +28,9 @@ enum DataLocation { }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); +#ifndef ORT_WASM64 static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); +#endif OrtErrorCode CheckStatus(OrtStatusPtr status) { if (status) { @@ -94,9 +96,10 @@ int OrtInit(int num_threads, int logging_level) { #endif } -void OrtGetLastError(int* error_code, const char** error_message) { +int OrtGetLastError(int* error_code, const char** error_message) { *error_code = g_last_error_code; *error_message = g_last_error_message.empty() ? nullptr : g_last_error_message.c_str(); + return ORT_OK; } OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, @@ -177,8 +180,9 @@ int OrtAddSessionConfigEntry(OrtSessionOptions* session_options, return CHECK_STATUS(AddSessionConfigEntry, session_options, config_key, config_value); } -void OrtReleaseSessionOptions(OrtSessionOptions* session_options) { +int OrtReleaseSessionOptions(OrtSessionOptions* session_options) { Ort::GetApi().ReleaseSessionOptions(session_options); + return ORT_OK; } OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* session_options) { @@ -196,8 +200,9 @@ OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* : nullptr; } -void OrtReleaseSession(OrtSession* session) { +int OrtReleaseSession(OrtSession* session) { Ort::GetApi().ReleaseSession(session); + return ORT_OK; } int OrtGetInputOutputCount(OrtSession* session, size_t* input_count, size_t* output_count) { @@ -226,11 +231,12 @@ char* OrtGetOutputName(OrtSession* session, size_t index) { : nullptr; } -void OrtFree(void* ptr) { +int OrtFree(void* ptr) { OrtAllocator* allocator = nullptr; if (CHECK_STATUS(GetAllocatorWithDefaultOptions, &allocator) == ORT_OK) { allocator->Free(allocator, ptr); } + return ORT_OK; } OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { @@ -287,7 +293,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* } } -int OrtGetTensorData(OrtValue* tensor, int* data_type, void** data, size_t** dims, size_t* dims_length) { +int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length) { ONNXType tensor_type; RETURN_ERROR_CODE_IF_ERROR(GetValueType, tensor, &tensor_type); if (tensor_type != ONNX_TYPE_TENSOR) { @@ -357,14 +363,15 @@ int OrtGetTensorData(OrtValue* tensor, int* data_type, void** data, size_t** dim *data = p_tensor_raw_data; } - *data_type = static_cast(type); + *data_type = static_cast(type); *dims_length = dims_len; *dims = UNREGISTER_AUTO_RELEASE(p_dims); return ORT_OK; } -void OrtReleaseTensor(OrtValue* tensor) { +int 
OrtReleaseTensor(OrtValue* tensor) { Ort::GetApi().ReleaseValue(tensor); + return ORT_OK; } OrtRunOptions* OrtCreateRunOptions(size_t log_severity_level, @@ -399,8 +406,9 @@ int OrtAddRunConfigEntry(OrtRunOptions* run_options, return CHECK_STATUS(AddRunConfigEntry, run_options, config_key, config_value); } -void OrtReleaseRunOptions(OrtRunOptions* run_options) { +int OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); + return ORT_OK; } OrtIoBinding* OrtCreateBinding(OrtSession* session) { @@ -445,12 +453,14 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, } } -void OrtClearBoundOutputs(OrtIoBinding* io_binding) { +int OrtClearBoundOutputs(OrtIoBinding* io_binding) { Ort::GetApi().ClearBoundOutputs(io_binding); + return ORT_OK; } -void OrtReleaseBinding(OrtIoBinding* io_binding) { +int OrtReleaseBinding(OrtIoBinding* io_binding) { Ort::GetApi().ReleaseIoBinding(io_binding); + return ORT_OK; } int OrtRunWithBinding(OrtSession* session, @@ -520,8 +530,9 @@ ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint( : nullptr; } -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle) { +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle) { Ort::GetTrainingApi().ReleaseCheckpointState(training_checkpoint_state_handle); + return ORT_OK; } ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(const ort_session_options_handle_t options, @@ -640,8 +651,9 @@ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_sessi } } -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { Ort::GetTrainingApi().ReleaseTrainingSession(training_handle); + return ORT_OK; } #endif diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 0730559c4375b..f44c515d98f6b 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -50,7 +50,7 @@ int EMSCRIPTEN_KEEPALIVE OrtInit(int num_threads, int logging_level); * @param error_code [out] a pointer to accept the error code. * @param error_message [out] a pointer to accept the error message. The message buffer is only available before any ORT API is called. */ -void EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_message); +int EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_message); /** * create an instance of ORT session options. @@ -109,7 +109,7 @@ int EMSCRIPTEN_KEEPALIVE OrtAddSessionConfigEntry(ort_session_options_handle_t s /** * release the specified ORT session options. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options); +int EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options); /** * create an instance of ORT session. @@ -124,7 +124,7 @@ ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, /** * release the specified ORT session. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); +int EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); /** * get model's input count and output count. @@ -158,7 +158,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, size_t * free the specified buffer. 
* @param ptr a pointer to the buffer. */ -void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); +int EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); /** * create an instance of ORT tensor. @@ -183,12 +183,12 @@ ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* da * 'dims' (for all types of tensor), 'data' (only for string tensor) * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ -int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, int* data_type, void** data, size_t** dims, size_t* dims_length); +int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length); /** * release the specified tensor. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); +int EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); /** * create an instance of ORT run options. @@ -218,7 +218,7 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio /** * release the specified ORT run options. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); +int EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); /** * create an instance of ORT IO binding. @@ -252,12 +252,12 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding, /** * clear all bound outputs. */ -void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); +int EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); /** * release the specified ORT IO binding. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); +int EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); /** * inference the model. @@ -311,7 +311,7 @@ ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint( * * @param training_checkpoint_state_handle handle for the CheckpointState */ -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle); +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle); /** * Creates an instance of a training session that can be used to begin or resume training from a given checkpoint state @@ -466,7 +466,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_sessi * * @param training_session_handle handle of the training session */ -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_session_handle); +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_session_handle); #endif }; diff --git a/onnxruntime/wasm/js_post_js.js b/onnxruntime/wasm/js_post_js.js new file mode 100644 index 0000000000000..b77d82fbd7d10 --- /dev/null +++ b/onnxruntime/wasm/js_post_js.js @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +// Licensed under the MIT License. + +'use strict'; + +Module["PTR_SIZE"] = 4; diff --git a/onnxruntime/wasm/js_post_js_64.js b/onnxruntime/wasm/js_post_js_64.js new file mode 100644 index 0000000000000..b140df927ebbd --- /dev/null +++ b/onnxruntime/wasm/js_post_js_64.js @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +// Licensed under the MIT License. 
+ +'use strict'; + +Module["PTR_SIZE"] = 8; diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 68332d07a9782..45e2475548df5 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -192,6 +192,9 @@ Module['jsepInit'] = (name, params) => { Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepOnCreateSession'] = sessionId => { + backend['onCreateSession'](sessionId); + }; Module['jsepOnReleaseSession'] = sessionId => { backend['onReleaseSession'](sessionId); }; @@ -234,6 +237,13 @@ Module['jsepInit'] = (name, params) => { } Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { return backend['registerMLTensor'](tensor, dataType, shape); - } + }; + Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { + return backend['createMLContext'](optionsOrGpuDevice); + }; + Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => { + return backend['registerMLConstant']( + externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); + }; } }; diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 87a7cbc0375a4..f1545e96481fa 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -758,7 +758,8 @@ Status TrainingSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/) const { + RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/, + const logging::Logger& /*logger*/) const { ORT_RETURN_IF_NOT( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations, "Only applying full build optimizations is supported by TrainingSession."); diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 765f88e1c992e..58492dc62400f 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -489,7 +489,8 @@ class TrainingSession : public InferenceSession { GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const override; + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const override; /** Perform auto-diff to add backward graph into the model. @param weights_to_train a set of weights to be training. 
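The training_session.cc/.h hunks above, together with the TryFindKernel and GenerateTransformers call-site updates further below, all follow the same pattern: an existing virtual method or helper gains a `logging::Logger&` parameter, and every override and caller is updated in the same change. A minimal sketch of the virtual-override side of that pattern is shown here; `Logger`, `InferenceSessionLike`, and `TrainingSessionLike` are simplified stand-ins for illustration only, not the real ORT classes, which carry more parameters than shown.

```cpp
#include <iostream>

// Stand-in for onnxruntime::logging::Logger (illustration only).
struct Logger {
  void Info(const char* msg) const { std::cout << msg << '\n'; }
};

class InferenceSessionLike {
 public:
  virtual ~InferenceSessionLike() = default;

  // The base signature gains the logger parameter...
  virtual void AddPredefinedTransformers(const Logger& logger) const {
    logger.Info("base: register default graph transformers");
  }
};

class TrainingSessionLike : public InferenceSessionLike {
 public:
  // ...and each override is updated in the same change; `override` turns any
  // signature mismatch into a compile error instead of silently hiding the base.
  void AddPredefinedTransformers(const Logger& logger) const override {
    logger.Info("training: only full-build optimizations are applied");
  }
};

int main() {
  Logger logger;
  TrainingSessionLike session;
  const InferenceSessionLike& base = session;
  base.AddPredefinedTransformers(logger);  // dispatches to the training override
  return 0;
}
```

Marking the derived method `override` is what makes this kind of signature change practical to roll out across a large codebase: any override that is not updated alongside the base declaration fails to compile rather than quietly becoming a separate, never-called method.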
diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index c98e5bcd97092..31591c0156b14 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -185,10 +185,13 @@ def build(self, *inputs_to_loss): logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(loaded_model, model_path), ( - onnxblock.custom_op_library(custom_op_library_path) - if custom_op_library is not None - else contextlib.nullcontext() + with ( + onnxblock.base(loaded_model, model_path), + ( + onnxblock.custom_op_library(custom_op_library_path) + if custom_op_library is not None + else contextlib.nullcontext() + ), ): _ = training_block(*[output.name for output in loaded_model.graph.output]) training_model, eval_model = training_block.to_model_proto() diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index ed68171cc6f9c..c13843f816f16 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -54,8 +54,10 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) if accessor._GLOBAL_ACCESSOR.has_path: + # `save` will destructively access any external data + copied_model = copy.deepcopy(accessor._GLOBAL_ACCESSOR.model) onnx.save( - accessor._GLOBAL_ACCESSOR.model, + copied_model, self.temp_onnx_file_path, save_as_external_data=True, all_tensors_to_one_file=True, diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 22627749c316c..d9cae8e1f99e8 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -867,8 +867,9 @@ def _get_exported_model( assert model_info_for_export.export_mode is not None, "Please use a concrete instance of ExecutionManager" try: - with torch.no_grad(), stage3_export_context( - enable_zero_stage3_support, stage3_param_handle, flattened_module + with ( + torch.no_grad(), + stage3_export_context(enable_zero_stage3_support, stage3_param_handle, flattened_module), ): required_export_kwargs = { "input_names": model_info_for_export.onnx_graph_input_names, # did not contains parameters as its input yet diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index 0944e46ff8eaf..58c173ed90277 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -139,7 +139,8 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, auto reg = execution_provider->GetKernelRegistry(); const KernelCreateInfo* kci; - auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, &kci); + auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, + DefaultLoggingManager().DefaultLogger(), &kci); if (!st.IsOK()) { // The goal here is unclear. 
It seems best to leave it to the Session // creation to figure out whether the model can be executed using some diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 548f39bb0150c..1b8699d1de497 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -23,8 +23,10 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger, + disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -47,8 +49,8 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 5c63be92d2b2f..0866d4a411e29 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1159,7 +1159,8 @@ def test_generate_artifacts_external_data_one_file(): assert os.path.exists(os.path.join(temp_dir, "checkpoint")) -def test_generate_artifacts_external_data_separate_files(): +@pytest.mark.parametrize("loss", [loss_t for loss_t in artifacts.LossType]) +def test_generate_artifacts_external_data_separate_files(loss): with tempfile.TemporaryDirectory() as temp_dir: _, simple_net = _get_models("cpu", 32, 28, 10, 10) @@ -1176,7 +1177,7 @@ def test_generate_artifacts_external_data_separate_files(): artifacts.generate_artifacts( os.path.join(temp_dir, "simple_net.onnx"), requires_grad=requires_grad_params, - loss=artifacts.LossType.CrossEntropyLoss, + loss=loss, optimizer=artifacts.OptimType.AdamW, artifact_directory=temp_dir, ) diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index 9b30bd128b161..d4f7fbf2080ce 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -3,7 +3,7 @@ #include "orttraining/training_ops/cuda/nn/conv_shared.h" -#include 
"core/platform/ort_mutex.h" +#include #include "core/providers/common.h" #include "core/providers/cuda/cuda_kernel.h" @@ -65,11 +65,11 @@ std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { template struct AlgoPerfCache { - mutable OrtMutex mutex; + mutable std::mutex mutex; std::unordered_map map; bool Find(const ConvParams& params, T_Perf* result) { - std::lock_guard guard(mutex); + std::lock_guard guard(mutex); auto it = map.find(params); if (it == map.end()) { return false; @@ -79,7 +79,7 @@ struct AlgoPerfCache { } void Insert(const ConvParams& params, const T_Perf& algo_perf) { - std::lock_guard guard(mutex); + std::lock_guard guard(mutex); map[params] = algo_perf; } }; diff --git a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc index 22fa5b6f55a5d..3b1ed29cb0240 100644 --- a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc @@ -7,7 +7,7 @@ #include "core/providers/common.h" #include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { namespace rocm { @@ -96,11 +96,11 @@ struct ConvParamsEqual { template struct AlgoPerfCache { - mutable OrtMutex mutex; + mutable std::mutex mutex; std::unordered_map map; bool Find(const ConvParams& params, T_Perf* result) { - std::lock_guard guard(mutex); + std::lock_guard guard(mutex); auto it = map.find(params); if (it == map.end()) { return false; @@ -110,7 +110,7 @@ struct AlgoPerfCache { } void Insert(const ConvParams& params, const T_Perf& algo_perf) { - std::lock_guard guard(mutex); + std::lock_guard guard(mutex); map[params] = algo_perf; } }; diff --git a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch b/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch index 3a408e2265fe7..29b8812c979e4 100644 --- a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch +++ b/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch @@ -46,7 +46,7 @@ RUN cd MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64 && \ rm -r MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64 ENV PATH=${OLD_PATH} -ENV unset OLD_PATH +ENV unset=OLD_PATH # python env RUN pip3 install --upgrade setuptools diff --git a/packages.config b/packages.config index 597ca77a321c5..877e2a17fd83e 100644 --- a/packages.config +++ b/packages.config @@ -1,6 +1,6 @@  - + diff --git a/pyproject.toml b/pyproject.toml index 6429df2722b2d..40e6eb96dff94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,8 @@ line-length = 120 # NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead extend-exclude = "cmake|onnxruntime/core/flatbuffers/" -target-version = ["py37", "py38", "py39", "py310", "py311"] +# NOTE: use the minimum supported python version as target-version +target-version = ["py310"] [tool.isort] # NOTE: Do not extend the exclude list. 
Edit .lintrunner.toml instead diff --git a/requirements-dev.txt b/requirements-dev.txt index 1b5ca65cf8037..b95b85781a398 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,6 @@ -black>=22.3 +-r requirements-lintrunner.txt cerberus flatbuffers -isort jinja2 numpy onnx diff --git a/setup.py b/setup.py index 034f63caf42b3..c1580eeb9e8f9 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,6 @@ def parse_arg_remove_string(argv, arg_name_equal): # Any combination of the following arguments can be applied if parse_arg_remove_boolean(sys.argv, "--nightly_build"): - package_name = "ort-nightly" nightly_build = True wheel_name_suffix = parse_arg_remove_string(sys.argv, "--wheel_name_suffix=") @@ -89,10 +88,10 @@ def parse_arg_remove_string(argv, arg_name_equal): pass elif parse_arg_remove_boolean(sys.argv, "--use_qnn"): is_qnn = True - package_name = "onnxruntime-qnn" if not nightly_build else "ort-nightly-qnn" + package_name = "onnxruntime-qnn" if is_rocm or is_migraphx: - package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly" + package_name = "onnxruntime-rocm" # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 @@ -224,7 +223,7 @@ def run(self): "libcudnn_heuristic.so.9", "libcudnn_ops.so.9", "libnvJitLink.so.12", - "libnvrtc.so.11", + "libnvrtc.so.11.2", # A symlink to libnvrtc.so.11.8.89 "libnvrtc.so.12", "libnvrtc-builtins.so.11", "libnvrtc-builtins.so.12", @@ -244,6 +243,7 @@ def run(self): "libMIOpen.so.1", "libnuma.so.1", "librccl.so.1", + "libhipblas.so.2", "librocblas.so.3", "librocblas.so.4", "librocfft.so.0", @@ -529,6 +529,8 @@ def finalize_options(self): "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Artificial Intelligence", @@ -537,14 +539,10 @@ def finalize_options(self): "Topic :: Software Development :: Libraries :: Python Modules", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Operating System :: Microsoft :: Windows", - "Operating System :: MacOS", + "Programming Language :: Python :: 3.13", ] if enable_training or enable_training_apis: diff --git a/tools/android_custom_build/Dockerfile b/tools/android_custom_build/Dockerfile index 0ff365dd5ff74..fcaffd9ef5e78 100644 --- a/tools/android_custom_build/Dockerfile +++ b/tools/android_custom_build/Dockerfile @@ -15,7 +15,7 @@ RUN apt-get update && apt-get install --yes --no-install-recommends \ ca-certificates \ git \ ninja-build \ - openjdk-11-jdk-headless \ + openjdk-17-jdk-headless \ python3-dev \ python3-numpy \ python3-pip \ diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index f1d3702e3245e..6a8154681ed97 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -11,7 +11,7 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): if not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) # Run hipify-perl first, capture output - s = subprocess.run([hipify_perl_path, "-roc", src_file_path], 
stdout=subprocess.PIPE, text=True, check=False).stdout + s = subprocess.run([hipify_perl_path, src_file_path], stdout=subprocess.PIPE, text=True, check=False).stdout # Additional exact-match replacements. # Order matters for all of the following replacements, reglardless of appearing in logical sections. @@ -22,22 +22,13 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut") s = s.replace("kTotalCudaStreams", "kTotalHipStreams") - # We want rocblas interfaces, not hipblas. Also force some hipify replacements back to rocblas from hipblas. - s = s.replace("CublasHandle", "RocblasHandle") - s = s.replace("cublas_handle", "rocblas_handle") - s = s.replace("hipblasHandle_t", "rocblas_handle") - s = s.replace("hipblasDatatype_t", "rocblas_datatype") - s = s.replace("HIPBLAS_STATUS_SUCCESS", "rocblas_status_success") - s = s.replace("hipblasStatus_t", "rocblas_status") - s = s.replace("hipblasCreate", "rocblas_create_handle") - s = s.replace("hipblasDestroy", "rocblas_destroy_handle") - s = s.replace("hipblasSetStream", "rocblas_set_stream") - s = s.replace("HIPBLAS_OP_T", "rocblas_operation_transpose") - s = s.replace("HIPBLAS_OP_N", "rocblas_operation_none") - # in rocm 6.0, hipify-perl, the -roc option also maps __half -> rocblas_half which we don't want s = s.replace("rocblas_half", "__half") + # these should be "hip" but it's easier to just use rocm to avoid complicated file renaming + s = s.replace("CudaGraph", "RocmGraph") + s = s.replace("CUDAGraph", "ROCMGraph") + s = s.replace("cuda_graph", "rocm_graph") s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels") s = s.replace("cudaEvent", "hipEvent") s = s.replace("CreateCudaAllocator", "CreateRocmAllocator") @@ -100,24 +91,15 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("typedef half MappedType", "typedef __half MappedType") # CUBLAS -> HIPBLAS - # Note: We do not use the hipblas marshalling interfaces; use rocblas instead. 
- # s = s.replace('CUBLAS', 'HIPBLAS') - # s = s.replace('Cublas', 'Hipblas') - # s = s.replace('cublas', 'hipblas') - - # CUBLAS -> ROCBLAS - s = s.replace("CUBLAS", "ROCBLAS") - s = s.replace("Cublas", "Rocblas") - s = s.replace("cublas", "rocblas") + s = s.replace("CUBLAS", "HIPBLAS") + s = s.replace("Cublas", "Hipblas") + s = s.replace("cublas", "hipblas") + # deprecated cublas symbol doesn't exist in hipblas, map to new symbol + s = s.replace("HIPBLAS_GEMM_DEFAULT_TENSOR_OP", "HIPBLAS_GEMM_DEFAULT") # Undefined ROCMRT constants -> std::numeric_limits s = s.replace("ROCMRT_INF_F", "std::numeric_limits::infinity()") - # HIPBLAS -> rocblas - s = s.replace("HIPBLAS_R_16F", "rocblas_datatype_f16_r") - s = s.replace("HIPBLAS_R_32F", "rocblas_datatype_f32_r") - s = s.replace("ROCBLAS_GEMM_DEFAULT_TENSOR_OP", "rocblas_gemm_algo_standard") - # compatible layer s = s.replace("rocblas_gemm_strided_batched_ex", "_compat_rocblas_gemm_strided_batched_ex") s = s.replace("RocblasMathModeSetter", "CompatRocblasMathModeSetter") diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 4cef66b26f0f4..3527a89ca7a7b 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -13,6 +13,7 @@ import shutil import subprocess import sys +import warnings from pathlib import Path @@ -253,7 +254,12 @@ def convert_arg_line_to_args(self, arg_line): "--cudnn_home is not specified.", ) parser.add_argument("--enable_cuda_line_info", action="store_true", help="Enable CUDA line info.") - parser.add_argument("--enable_cuda_nhwc_ops", action="store_true", help="Enable CUDA NHWC ops in build.") + + parser.add_argument( + "--enable_cuda_nhwc_ops", action="store_true", help="Deprecated; default to enable CUDA NHWC ops in build." + ) + + parser.add_argument("--disable_cuda_nhwc_ops", action="store_true", help="Disable CUDA NHWC ops in build.") # Python bindings parser.add_argument("--enable_pybind", action="store_true", help="Enable Python Bindings.") @@ -399,7 +405,7 @@ def convert_arg_line_to_args(self, arg_line): help="Build with a specific GDK edition. 
Defaults to the latest installed.", ) parser.add_argument("--gdk_platform", default="Scarlett", help="Sets the GDK target platform.") - + parser.add_argument("--enable_wasm_memory64", action="store_true", help="Enable WebAssembly 64bit support") platform_group = parser.add_mutually_exclusive_group() platform_group.add_argument("--ios", action="store_true", help="build for ios") platform_group.add_argument("--visionos", action="store_true", help="build for visionOS") @@ -571,19 +577,14 @@ def convert_arg_line_to_args(self, arg_line): ) parser.add_argument("--use_jsep", action="store_true", help="Build with JavaScript kernels.") parser.add_argument("--use_webgpu", action="store_true", help="Build with WebGPU support.") + parser.add_argument("--use_external_dawn", action="store_true", help="Treat Dawn as an external dependency.") parser.add_argument("--use_qnn", action="store_true", help="Build with QNN support.") parser.add_argument("--qnn_home", help="Path to QNN SDK dir.") parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") parser.add_argument("--use_preinstalled_eigen", action="store_true", help="Use pre-installed Eigen.") parser.add_argument("--eigen_path", help="Path to pre-installed Eigen.") parser.add_argument("--enable_msinternal", action="store_true", help="Enable for Microsoft internal builds only.") - parser.add_argument("--llvm_path", help="Path to llvm dir") parser.add_argument("--use_vitisai", action="store_true", help="Build with Vitis-AI") - parser.add_argument("--use_tvm", action="store_true", help="Build with TVM") - parser.add_argument("--tvm_cuda_runtime", action="store_true", default=False, help="Build TVM with CUDA support") - parser.add_argument( - "--use_tvm_hash", action="store_true", help="Build ipp-crypto for hash generation. It is used by TVM EP only" - ) parser.add_argument("--use_tensorrt", action="store_true", help="Build with TensorRT") parser.add_argument( "--use_tensorrt_builtin_parser", action="store_true", default=True, help="Use TensorRT builtin parser" @@ -595,12 +596,6 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--migraphx_home", help="Path to MIGraphX installation dir") parser.add_argument("--use_full_protobuf", action="store_true", help="Use the full protobuf library") - parser.add_argument( - "--llvm_config", - type=str, - default="", - help="Path to llvm-config.exe for LLVM built from sources. It is strongly needed for build on Windows", - ) parser.add_argument( "--skip_onnx_tests", action="store_true", @@ -761,6 +756,7 @@ def convert_arg_line_to_args(self, arg_line): ) parser.add_argument("--use_xnnpack", action="store_true", help="Enable xnnpack EP.") + parser.add_argument("--use_avx512", action="store_true", help="Enable AVX512 instructions") parser.add_argument("--use_azure", action="store_true", help="Enable azure EP.") parser.add_argument("--use_cache", action="store_true", help="Use compiler cache in CI") @@ -791,6 +787,11 @@ def convert_arg_line_to_args(self, arg_line): if args.cmake_generator is None and is_windows(): args.cmake_generator = "Ninja" if args.build_wasm else "Visual Studio 17 2022" + if args.enable_cuda_nhwc_ops: + warnings.warn( + "The argument '--enable_cuda_nhwc_ops' is deprecated and is default to True. 
", DeprecationWarning + ) + return args @@ -1018,16 +1019,11 @@ def generate_build_tree( "-Donnxruntime_USE_NNAPI_BUILTIN=" + ("ON" if args.use_nnapi else "OFF"), "-Donnxruntime_USE_VSINPU=" + ("ON" if args.use_vsinpu else "OFF"), "-Donnxruntime_USE_RKNPU=" + ("ON" if args.use_rknpu else "OFF"), - "-Donnxruntime_USE_LLVM=" + ("ON" if args.use_tvm else "OFF"), "-Donnxruntime_ENABLE_MICROSOFT_INTERNAL=" + ("ON" if args.enable_msinternal else "OFF"), "-Donnxruntime_USE_VITISAI=" + ("ON" if args.use_vitisai else "OFF"), "-Donnxruntime_USE_TENSORRT=" + ("ON" if args.use_tensorrt else "OFF"), "-Donnxruntime_USE_TENSORRT_BUILTIN_PARSER=" + ("ON" if args.use_tensorrt_builtin_parser and not args.use_tensorrt_oss_parser else "OFF"), - # set vars for TVM - "-Donnxruntime_USE_TVM=" + ("ON" if args.use_tvm else "OFF"), - "-Donnxruntime_TVM_CUDA_RUNTIME=" + ("ON" if args.use_tvm and args.tvm_cuda_runtime else "OFF"), - "-Donnxruntime_TVM_USE_HASH=" + ("ON" if args.use_tvm_hash else "OFF"), # set vars for migraphx "-Donnxruntime_USE_MIGRAPHX=" + ("ON" if args.use_migraphx else "OFF"), "-Donnxruntime_DISABLE_CONTRIB_OPS=" + ("ON" if args.disable_contrib_ops else "OFF"), @@ -1057,6 +1053,7 @@ def generate_build_tree( "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), "-Donnxruntime_USE_JSEP=" + ("ON" if args.use_jsep else "OFF"), "-Donnxruntime_USE_WEBGPU=" + ("ON" if args.use_webgpu else "OFF"), + "-Donnxruntime_USE_EXTERNAL_DAWN=" + ("ON" if args.use_external_dawn else "OFF"), # Training related flags "-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"), "-Donnxruntime_ENABLE_TRAINING=" + ("ON" if args.enable_training else "OFF"), @@ -1071,7 +1068,7 @@ def generate_build_tree( "-Donnxruntime_USE_MPI=" + ("ON" if args.use_mpi else "OFF"), "-Donnxruntime_ENABLE_MEMORY_PROFILE=" + ("ON" if args.enable_memory_profile else "OFF"), "-Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=" + ("ON" if args.enable_cuda_line_info else "OFF"), - "-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.enable_cuda_nhwc_ops else "OFF"), + "-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.use_cuda and not args.disable_cuda_nhwc_ops else "OFF"), "-Donnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB=" + ("ON" if args.build_wasm_static_lib else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=" + ("OFF" if args.disable_wasm_exception_catching else "ON"), @@ -1081,6 +1078,7 @@ def generate_build_tree( + ("ON" if args.enable_wasm_exception_throwing_override else "OFF"), "-Donnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER=" + ("ON" if args.wasm_run_tests_in_browser else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=" + ("ON" if args.enable_wasm_threads else "OFF"), + "-Donnxruntime_ENABLE_WEBASSEMBLY_MEMORY64=" + ("ON" if args.enable_wasm_memory64 else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"), "-Donnxruntime_ENABLE_LAZY_TENSOR=" + ("ON" if args.enable_lazy_tensor else "OFF"), @@ -1157,8 +1155,6 @@ def generate_build_tree( cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) - if args.llvm_config: - cmake_args.append("-Donnxruntime_TVM_USE_LLVM=" + args.llvm_config) if args.use_cuda: add_default_definition(cmake_extra_defines, "onnxruntime_USE_CUDA", "ON") @@ -1176,11 +1172,14 @@ def generate_build_tree( ) 
add_default_definition(cmake_extra_defines, "ONNX_USE_MSVC_STATIC_RUNTIME", "ON") add_default_definition(cmake_extra_defines, "protobuf_MSVC_STATIC_RUNTIME", "ON") + # The following build option was added in ABSL 20240722.0 and it must be explicitly set + add_default_definition(cmake_extra_defines, "ABSL_MSVC_STATIC_RUNTIME", "ON") add_default_definition(cmake_extra_defines, "gtest_force_shared_crt", "OFF") else: # CMAKE_MSVC_RUNTIME_LIBRARY is default to MultiThreaded$<$:Debug>DLL add_default_definition(cmake_extra_defines, "ONNX_USE_MSVC_STATIC_RUNTIME", "OFF") add_default_definition(cmake_extra_defines, "protobuf_MSVC_STATIC_RUNTIME", "OFF") + add_default_definition(cmake_extra_defines, "ABSL_MSVC_STATIC_RUNTIME", "OFF") add_default_definition(cmake_extra_defines, "gtest_force_shared_crt", "ON") if acl_home and os.path.exists(acl_home): @@ -1238,9 +1237,6 @@ def generate_build_tree( if args.use_full_protobuf or args.use_openvino or args.use_vitisai or args.gen_doc: cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] - if args.use_tvm and args.llvm_path is not None: - cmake_args += [f"-DLLVM_DIR={args.llvm_path}"] - if args.use_cuda and not is_windows(): nvml_stub_path = cuda_home + "/lib64/stubs" cmake_args += ["-DCUDA_CUDA_LIBRARY=" + nvml_stub_path] @@ -1316,6 +1312,9 @@ def generate_build_tree( if args.use_jsep and args.use_webgpu: raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + if args.use_external_dawn and not args.use_webgpu: + raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).") + if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] @@ -1548,11 +1547,7 @@ def generate_build_tree( and not args.build_wasm ): if is_windows(): - # DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler - # See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856 - # Remove this definition (_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR) - # once the conda msvcp140.dll dll is updated. - cflags += ["/guard:cf", "/DWIN32", "/D_WINDOWS", "/D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR"] + cflags += ["/guard:cf", "/DWIN32", "/D_WINDOWS"] if not args.use_gdk: # Target Windows 10 cflags += [ @@ -1566,8 +1561,7 @@ def generate_build_tree( ldflags = ["/profile", "/DYNAMICBASE"] # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled. 
if not args.enable_address_sanitizer: - # Also enable a special perf patch that was made for Intel Meteor Lake mobile CPUs - cflags += ["/Qspectre", "/DONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH"] + cflags += ["/Qspectre"] if config == "Release": cflags += ["/O2", "/Ob2", "/DNDEBUG"] elif config == "RelWithDebInfo": @@ -1642,16 +1636,6 @@ def generate_build_tree( cxxflags = cflags.copy() config_build_dir = get_config_build_dir(build_dir, config) os.makedirs(config_build_dir, exist_ok=True) - if args.use_tvm: - os.environ["PATH"] = ( - os.path.join(config_build_dir, "_deps", "tvm-build") - + os.pathsep - + os.path.join(config_build_dir, "_deps", "tvm-src") - + os.pathsep - + os.path.dirname(sys.executable) - + os.pathsep - + os.environ["PATH"] - ) preinstalled_dir = Path(build_dir) / config temp_cmake_args = cmake_args.copy() if cflags is not None and cxxflags is not None and len(cflags) != 0 and len(cxxflags) != 0: @@ -1866,6 +1850,10 @@ def setup_rocm_build(args): def run_android_tests(args, source_dir, build_dir, config, cwd): + if args.android_abi != "x86_64": + log.info(f"--android_abi ({args.android_abi}) is not x86_64, skipping running of Android tests on emulator.") + return + sdk_tool_paths = android.get_sdk_tool_paths(args.android_sdk_path) device_dir = "/data/local/tmp" @@ -1887,72 +1875,85 @@ def run_adb_shell(cmd): else: adb_shell(f"cd {device_dir} && {cmd}") - if args.android_abi == "x86_64": - with contextlib.ExitStack() as context_stack: - if args.android_run_emulator: - avd_name = "ort_android" - system_image = f"system-images;android-{args.android_api};default;{args.android_abi}" - - android.create_virtual_device(sdk_tool_paths, system_image, avd_name) - emulator_proc = context_stack.enter_context( - android.start_emulator( - sdk_tool_paths=sdk_tool_paths, - avd_name=avd_name, - extra_args=["-partition-size", "2047", "-wipe-data"], - ) + with contextlib.ExitStack() as context_stack: + if args.android_run_emulator: + avd_name = "ort_android" + system_image = f"system-images;android-{args.android_api};default;{args.android_abi}" + + android.create_virtual_device(sdk_tool_paths, system_image, avd_name) + emulator_proc = context_stack.enter_context( + android.start_emulator( + sdk_tool_paths=sdk_tool_paths, + avd_name=avd_name, + extra_args=["-partition-size", "2047", "-wipe-data"], ) - context_stack.callback(android.stop_emulator, emulator_proc) + ) + context_stack.callback(android.stop_emulator, emulator_proc) + + adb_push("testdata", device_dir, cwd=cwd) + adb_push(os.path.join(source_dir, "cmake", "external", "onnx", "onnx", "backend", "test"), device_dir, cwd=cwd) + adb_push("onnxruntime_test_all", device_dir, cwd=cwd) + adb_shell(f"chmod +x {device_dir}/onnxruntime_test_all") + adb_push("onnx_test_runner", device_dir, cwd=cwd) + adb_shell(f"chmod +x {device_dir}/onnx_test_runner") + run_adb_shell(f"{device_dir}/onnxruntime_test_all") + + # remove onnxruntime_test_all as it takes up a _lot_ of space and can cause insufficient storage errors + # when we try to copy the java app to the device. 
+ adb_shell(f"rm {device_dir}/onnxruntime_test_all") + + if args.build_java: + # use the gradle wrapper under /java + gradle_executable = os.path.join(source_dir, "java", "gradlew.bat" if is_windows() else "gradlew") + android_test_path = os.path.join(cwd, "java", "androidtest", "android") + run_subprocess( + [ + gradle_executable, + "--no-daemon", + f"-DminSdkVer={args.android_api}", + "clean", + "connectedDebugAndroidTest", + ], + cwd=android_test_path, + ) - adb_push("testdata", device_dir, cwd=cwd) - adb_push( - os.path.join(source_dir, "cmake", "external", "onnx", "onnx", "backend", "test"), device_dir, cwd=cwd + if args.use_nnapi: + run_adb_shell(f"{device_dir}/onnx_test_runner -e nnapi {device_dir}/test") + else: + run_adb_shell(f"{device_dir}/onnx_test_runner {device_dir}/test") + + # run shared_lib_test if necessary + if args.build_shared_lib: + adb_push("libonnxruntime.so", device_dir, cwd=cwd) + adb_push("onnxruntime_shared_lib_test", device_dir, cwd=cwd) + adb_push("libcustom_op_library.so", device_dir, cwd=cwd) + adb_push("libcustom_op_get_const_input_test_library.so", device_dir, cwd=cwd) + adb_push("onnxruntime_customopregistration_test", device_dir, cwd=cwd) + adb_shell(f"chmod +x {device_dir}/onnxruntime_shared_lib_test") + adb_shell(f"chmod +x {device_dir}/onnxruntime_customopregistration_test") + run_adb_shell(f"LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{device_dir} {device_dir}/onnxruntime_shared_lib_test") + run_adb_shell( + f"LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{device_dir} {device_dir}/onnxruntime_customopregistration_test" ) - adb_push("onnxruntime_test_all", device_dir, cwd=cwd) - adb_shell(f"chmod +x {device_dir}/onnxruntime_test_all") - adb_push("onnx_test_runner", device_dir, cwd=cwd) - adb_shell(f"chmod +x {device_dir}/onnx_test_runner") - run_adb_shell(f"{device_dir}/onnxruntime_test_all") - - # remove onnxruntime_test_all as it takes up a _lot_ of space and can cause insufficient storage errors - # when we try to copy the java app to the device. 
- adb_shell(f"rm {device_dir}/onnxruntime_test_all") - - if args.build_java: - # use the gradle wrapper under /java - gradle_executable = os.path.join(source_dir, "java", "gradlew.bat" if is_windows() else "gradlew") - android_test_path = os.path.join(cwd, "java", "androidtest", "android") - run_subprocess( - [ - gradle_executable, - "--no-daemon", - f"-DminSdkVer={args.android_api}", - "clean", - "connectedDebugAndroidTest", - ], - cwd=android_test_path, - ) - if args.use_nnapi: - run_adb_shell(f"{device_dir}/onnx_test_runner -e nnapi {device_dir}/test") - else: - run_adb_shell(f"{device_dir}/onnx_test_runner {device_dir}/test") - # run shared_lib_test if necessary - if args.build_shared_lib: - adb_push("libonnxruntime.so", device_dir, cwd=cwd) - adb_push("onnxruntime_shared_lib_test", device_dir, cwd=cwd) - adb_push("libcustom_op_library.so", device_dir, cwd=cwd) - adb_push("libcustom_op_get_const_input_test_library.so", device_dir, cwd=cwd) - adb_push("onnxruntime_customopregistration_test", device_dir, cwd=cwd) - adb_shell(f"chmod +x {device_dir}/onnxruntime_shared_lib_test") - adb_shell(f"chmod +x {device_dir}/onnxruntime_customopregistration_test") - run_adb_shell(f"LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{device_dir} {device_dir}/onnxruntime_shared_lib_test") - run_adb_shell( - f"LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{device_dir} {device_dir}/onnxruntime_customopregistration_test" - ) +def run_ios_tests(args, source_dir, config, cwd): + is_targeting_iphone_simulator = "iphonesimulator" in args.apple_sysroot.lower() + if not is_targeting_iphone_simulator: + log.info( + f"Could not detect iphonesimulator target from --apple_sysroot ({args.apple_sysroot}), " + "skipping running of iOS tests on simulator." + ) + return + host_arch = platform.machine() + if host_arch != args.osx_arch: + log.info( + f"Host arch ({host_arch}) and --osx_arch ({args.osx_arch}) mismatch, " + "skipping running of iOS tests on simulator." + ) + return -def run_ios_tests(args, source_dir, config, cwd): simulator_device_info = subprocess.check_output( [ sys.executable, @@ -2063,8 +2064,6 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if args.enable_pybind: python_path = None - if args.use_tvm: - python_path = str((Path(build_dir) / config / "_deps" / "tvm-src" / "python").resolve()) # Disable python tests in a reduced build as we don't know which ops have been included and which # models can run. @@ -2074,6 +2073,17 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if is_windows(): cwd = os.path.join(cwd, config) + if args.enable_transformers_tool_test and not args.disable_contrib_ops and not args.use_rocm: + # PyTorch is required for transformers tests, and optional for some python tests. + # Install cpu only version of torch when cuda is not enabled in Linux. 
+ extra = [] if args.use_cuda and is_linux() else ["--index-url", "https://download.pytorch.org/whl/cpu"] + run_subprocess( + [sys.executable, "-m", "pip", "install", "torch", *extra], + cwd=cwd, + dll_path=dll_path, + python_path=python_path, + ) + run_subprocess( [sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path ) @@ -2102,10 +2112,10 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if not args.disable_ml_ops and not args.use_tensorrt: run_subprocess([sys.executable, "onnxruntime_test_python_mlops.py"], cwd=cwd, dll_path=dll_path) - # if args.use_tensorrt: - # run_subprocess( - # [sys.executable, "onnxruntime_test_python_nested_control_flow_op.py"], cwd=cwd, dll_path=dll_path - # ) + if args.use_tensorrt: + run_subprocess( + [sys.executable, "onnxruntime_test_python_nested_control_flow_op.py"], cwd=cwd, dll_path=dll_path + ) try: import onnx # noqa: F401 @@ -2128,6 +2138,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): dll_path=dll_path, python_path=python_path, ) + if not args.disable_contrib_ops: run_subprocess( [sys.executable, "-m", "unittest", "discover", "-s", "quantization"], cwd=cwd, dll_path=dll_path @@ -2149,7 +2160,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): ], cwd=SCRIPT_DIR, ) - run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) + run_subprocess([sys.executable, "-m", "pytest", "--durations=0", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it run_subprocess([sys.executable, "-m", "pip", "install", "numpy==" + numpy_init_version]) run_subprocess([sys.executable, "-m", "pip", "install", "protobuf==" + pb_init_version]) @@ -2187,17 +2198,6 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): run_subprocess([sys.executable, "onnxruntime_test_python_keras.py"], cwd=cwd, dll_path=dll_path) -def tvm_run_python_tests(build_dir, configs): - for config in configs: - cwd = get_config_build_dir(build_dir, config) - if is_windows(): - cwd = os.path.join(cwd, config) - python_path = os.path.join(build_dir, config, "_deps", "tvm-src", "python") - run_subprocess( - [sys.executable, "onnxruntime_test_python_tvm.py"], cwd=cwd, python_path=os.path.abspath(python_path) - ) - - def run_nodejs_tests(nodejs_binding_dir): args = ["npm", "test", "--", "--timeout=90000"] if is_windows(): @@ -2217,7 +2217,6 @@ def build_python_wheel( use_dnnl, use_tensorrt, use_openvino, - use_tvm, use_vitisai, use_acl, use_armnn, @@ -2270,8 +2269,6 @@ def build_python_wheel( args.append("--use_openvino") elif use_dnnl: args.append("--use_dnnl") - elif use_tvm: - args.append("--use_tvm") elif use_vitisai: args.append("--use_vitisai") elif use_acl: @@ -2300,7 +2297,6 @@ def build_nuget_package( use_openvino, use_tensorrt, use_dnnl, - use_tvm, use_winml, use_qnn, enable_training_apis, @@ -2336,7 +2332,7 @@ def build_nuget_package( target_name = "/t:CreateWindowsAIPackage" elif use_openvino: execution_provider = "/p:ExecutionProvider=openvino" - package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.OpenVino" + package_name = "/p:OrtPackageId=Intel.ML.OnnxRuntime.OpenVino" elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" @@ -2347,9 +2343,6 @@ def build_nuget_package( package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu" elif use_rocm: package_name = 
"/p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm" - elif use_tvm: - execution_provider = "/p:ExecutionProvider=tvm" - package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Tvm" elif use_qnn: execution_provider = "/p:ExecutionProvider=qnn" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.QNN" @@ -2591,7 +2584,7 @@ def main(): if args.use_tensorrt: args.use_cuda = True - if args.build_wheel or args.gen_doc or args.use_tvm or args.enable_training: + if args.build_wheel or args.gen_doc or args.enable_training: args.enable_pybind = True if ( @@ -2873,12 +2866,6 @@ def main(): run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs) - # TODO(agladyshev): - # to support Windows, we need to update .github/workflows/windows.yml - # and add to the PATH variable the following value: C:Program Files\LLVM\bin - if args.enable_pybind and args.use_tvm and not is_windows(): - tvm_run_python_tests(build_dir, configs) - # run node.js binding tests if args.build_nodejs and not args.skip_nodejs_tests: nodejs_binding_dir = os.path.normpath(os.path.join(source_dir, "js", "node")) @@ -2906,7 +2893,6 @@ def main(): args.use_dnnl, args.use_tensorrt, args.use_openvino, - args.use_tvm, args.use_vitisai, args.use_acl, args.use_armnn, @@ -2934,7 +2920,6 @@ def main(): args.use_openvino, args.use_tensorrt, args.use_dnnl, - args.use_tvm, args.use_winml, args.use_qnn, args.enable_training_apis, diff --git a/tools/ci_build/clean_docker_image_cache.py b/tools/ci_build/clean_docker_image_cache.py deleted file mode 100755 index 8ec2b6b438176..0000000000000 --- a/tools/ci_build/clean_docker_image_cache.py +++ /dev/null @@ -1,258 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import argparse -import collections -import datetime -import json -import os -import re -import sys -import tempfile - -from logger import get_logger - -SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) -REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) - -sys.path.append(os.path.join(REPO_DIR, "tools", "python")) - - -from util import run # noqa: E402 - -log = get_logger("clean_docker_image_cache") - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Cleans the docker image cache container registry. " - "This assumes a fairly specific setup - an Azure container registry " - "and a storage account that receives " - "ContainerRegistryRepositoryEvents logs from that registry. " - "The logs are searched in order to determine whether images should be " - "retained or removed. " - "For an image to be retained, it must have been accessed at least N " - "times (specified by --cache-min-access-count) over the past K days " - "(specified by --cache-history-days)." - ) - - parser.add_argument("--container-registry", required=True, help="The container registry name.") - - parser.add_argument("--log-storage-account", required=True, help="The storage account name.") - parser.add_argument("--log-storage-account-container", required=True, help="The storage account container name.") - parser.add_argument( - "--log-storage-path-pattern", default="*.json", help="The log path pattern in the storage account container." - ) - - parser.add_argument("--cache-history-days", type=int, default=7, help="The length of the cache history in days.") - - parser.add_argument( - "--cache-min-access-count", type=int, default=1, help="The minimum access count over the cache history." 
- ) - - parser.add_argument("--dry-run", action="store_true", help="Do a dry-run and do not remove any images.") - - parser.add_argument("--az-path", default="az", help="Path to the az client.") - - return parser.parse_args() - - -def az(*args, parse_output=True, az_path): - proc = run(az_path, *args, "--output", "json", capture_stdout=parse_output) - if parse_output: - return json.loads(proc.stdout.decode()) - return None - - -def download_logs(storage_account, container, log_path_pattern, target_dir, az_path): - log_paths = az( - "storage", - "blob", - "download-batch", - "--destination", - target_dir, - "--source", - container, - "--account-name", - storage_account, - "--pattern", - log_path_pattern, - az_path=az_path, - ) - return [os.path.join(target_dir, log_path) for log_path in log_paths] - - -ImageInfo = collections.namedtuple("ImageInfo", ["repository", "digest"]) - - -def get_image_name(image_info): - return f"{image_info.repository}@{image_info.digest}" - - -timestamp_pattern = re.compile( - r"^(?P\d+)-(?P\d+)-(?P\d+)T(?P\d+):(?P\d+):(?P\d+)" -) - - -def parse_timestamp(timestamp_str): - match = timestamp_pattern.match(timestamp_str) - if match is None: - return None - - return datetime.datetime( - year=int(match["year"]), - month=int(match["month"]), - day=int(match["day"]), - hour=int(match["hour"]), - minute=int(match["minute"]), - second=int(match["second"]), - tzinfo=datetime.timezone.utc, - ) - - -def parse_log_line(line, min_datetime): - entry = json.loads(line) - - def check_time(value): - timestamp = parse_timestamp(value) - return timestamp is not None and timestamp >= min_datetime - - for field_name, expected_value_or_checker in [ - ("category", "ContainerRegistryRepositoryEvents"), - ("operationName", lambda value: value in ["Pull", "Push"]), - ("resultType", "HttpStatusCode"), - ("resultDescription", lambda value: value in ["200", "201"]), - ("time", check_time), - ]: - value = entry.get(field_name, "") - if callable(expected_value_or_checker): - if not expected_value_or_checker(value): - return None - else: - if value != expected_value_or_checker: - return None - - props = entry.get("properties", {}) - repo, digest = props.get("repository"), props.get("digest") - - if repo is None or digest is None: - return None - - return ImageInfo(repo, digest) - - -def get_valid_images_from_logs(log_paths, min_datetime, min_access_count): - image_counts = dict() # dict of {ImageInfo -> count} - - for log_path in log_paths: - log.debug(f"Processing log file: {log_path}") - with open(log_path) as log_file: - for line in log_file: - image_info = parse_log_line(line, min_datetime) - if image_info is not None: - image_counts[image_info] = image_counts.get(image_info, 0) + 1 - - return {image for image, count in image_counts.items() if count >= min_access_count} - - -def get_registry_images(container_registry, az_path): - registry_images = set() # set of ImageInfo - - repositories = az("acr", "repository", "list", "--name", container_registry, az_path=az_path) - - for repository in repositories: - digests = az( - "acr", - "repository", - "show-manifests", - "--repository", - repository, - "--name", - container_registry, - "--query", - "[*].digest", - az_path=az_path, - ) - - registry_images.update([ImageInfo(repository, digest) for digest in digests]) - - return registry_images - - -def clean_images(container_registry, image_names, az_path): - for image_name in image_names: - az( - "acr", - "repository", - "delete", - "--name", - container_registry, - "--image", - image_name, - 
"--yes", - az_path=az_path, - parse_output=False, - ) - - -# Note: -# the log download and parsing could be replaced by a log analytics query -""" -let cache_history = 7d; -let cache_min_access_count = 1; -ContainerRegistryRepositoryEvents -| where TimeGenerated >= ago(cache_history) -| where OperationName in ("Pull", "Push") -| where ResultDescription in ("200", "201") -| summarize AccessCount = count() by Repository, Digest -| where AccessCount >= cache_min_access_count -| project Repository, Digest -""" -# need to figure out how run the query the programmatically though - - -def main(): - args = parse_args() - - valid_images = set() - - with tempfile.TemporaryDirectory() as tmp_dir: - log_paths = download_logs( - args.log_storage_account, - args.log_storage_account_container, - args.log_storage_path_pattern, - tmp_dir, - args.az_path, - ) - - cache_history = datetime.timedelta(days=args.cache_history_days) - - min_timestamp = datetime.datetime.now(tz=datetime.timezone.utc) - cache_history - - valid_images = get_valid_images_from_logs(log_paths, min_timestamp, args.cache_min_access_count) - - all_images = get_registry_images(args.container_registry, args.az_path) - - def sorted_image_names(image_infos): - return sorted([get_image_name(image_info) for image_info in image_infos]) - - log.debug("All images:\n{}".format("\n".join(sorted_image_names(all_images)))) # noqa: G001 - log.debug("Valid images:\n{}".format("\n".join(sorted_image_names(valid_images)))) # noqa: G001 - - images_to_clean = all_images - valid_images - image_names_to_clean = sorted_image_names(images_to_clean) - - log.info("Images to clean:\n{}".format("\n".join(image_names_to_clean))) # noqa: G001 - - if args.dry_run: - log.info("Dry run, no images will be cleaned.") - return 0 - - clean_images(args.container_registry, image_names_to_clean, args.az_path) - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh b/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh index 88fb578c591b8..29c52404dc7e3 100755 --- a/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh +++ b/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh @@ -6,39 +6,57 @@ set -e set -x -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp312-cp312/bin:$PATH ls /build ls /build/deps + +# User inputs +USE_QNN=${1:-0} # by default qnn will not be included in package + # build the AAR package, using the build settings under /home/onnxruntimedev/.build_settings/ # if there is also include_ops_and_types.config exists in the same folder, use it to build with included ops/types -if [ -f "/home/onnxruntimedev/.build_settings/include_ops_and_types.config" ]; then - python3 /onnxruntime_src/tools/ci_build/github/android/build_aar_package.py \ - --build_dir /build \ - --config $BUILD_CONFIG \ - --android_sdk_path /android_home \ - --android_ndk_path /ndk_home \ - --include_ops_by_config /home/onnxruntimedev/.build_settings/include_ops_and_types.config \ - /home/onnxruntimedev/.build_settings/build_settings.json -else - python3 /onnxruntime_src/tools/ci_build/github/android/build_aar_package.py \ - --build_dir /build \ - --config $BUILD_CONFIG \ - --android_sdk_path /android_home \ - --android_ndk_path /ndk_home \ - /home/onnxruntimedev/.build_settings/build_settings.json + +BUILD_SCRIPT="/onnxruntime_src/tools/ci_build/github/android/build_aar_package.py" +BUILD_SETTINGS="/home/onnxruntimedev/.build_settings/build_settings.json" 
+INCLUDE_OPS_CONFIG="/home/onnxruntimedev/.build_settings/include_ops_and_types.config" + +ANDROID_SDK_HOME="/android_home" +ANDROID_NDK_HOME="/ndk_home" +QNN_HOME="/qnn_home" + + +# Base command for building the AAR package +COMMAND="python3 $BUILD_SCRIPT --build_dir /build --config $BUILD_CONFIG --android_sdk_path $ANDROID_SDK_HOME --android_ndk_path $ANDROID_NDK_HOME $BUILD_SETTINGS" + +# Check if the include ops config file exists and modify command if it does +if [ -f "$INCLUDE_OPS_CONFIG" ]; then + COMMAND+=" --include_ops_by_config $INCLUDE_OPS_CONFIG" fi +# Add qnn path to command +if [ "$USE_QNN" == "1" ]; then + if [ -d "$QNN_HOME" ]; then + COMMAND+=" --qnn_path $QNN_HOME" + else + echo "Error: QNN directory does not exist." + exit 1 + fi +fi + +# Execute the build command +eval $COMMAND + # Copy the built artifacts to give folder for publishing -BASE_PATH=/build/aar_out/${BUILD_CONFIG}/com/microsoft/onnxruntime/${PACKAGE_NAME}/${ORT_VERSION} -cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}-javadoc.jar /home/onnxruntimedev/.artifacts -cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}-sources.jar /home/onnxruntimedev/.artifacts -cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}.aar /home/onnxruntimedev/.artifacts -cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}.pom /home/onnxruntimedev/.artifacts +BASE_PATH=/build/aar_out/${BUILD_CONFIG}/com/microsoft/onnxruntime/${PACKAGE_NAME}/${ORT_VERSION}${RELEASE_VERSION_SUFFIX} +cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}${RELEASE_VERSION_SUFFIX}-javadoc.jar /home/onnxruntimedev/.artifacts +cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}${RELEASE_VERSION_SUFFIX}-sources.jar /home/onnxruntimedev/.artifacts +cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}${RELEASE_VERSION_SUFFIX}.aar /home/onnxruntimedev/.artifacts +cp ${BASE_PATH}/${PACKAGE_NAME}-${ORT_VERSION}${RELEASE_VERSION_SUFFIX}.pom /home/onnxruntimedev/.artifacts # Copy executables if necessary if [ "$PUBLISH_EXECUTABLES" == "1" ]; then pushd /build/intermediates/executables/${BUILD_CONFIG} - tar -czvf /home/onnxruntimedev/.artifacts/${PACKAGE_NAME}-${ORT_VERSION}-executables.tgz * + tar -czvf /home/onnxruntimedev/.artifacts/${PACKAGE_NAME}-${ORT_VERSION}${RELEASE_VERSION_SUFFIX}-executables.tgz * popd fi diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py index 036db703ccf27..1b34b3d302e57 100644 --- a/tools/ci_build/github/android/build_aar_package.py +++ b/tools/ci_build/github/android/build_aar_package.py @@ -23,11 +23,11 @@ # Onnx Runtime native library is built against NDK API 21 by default # It is possible to build from source for Android API levels below 21, but it is not guaranteed -DEFAULT_ANDROID_MIN_SDK_VER = 21 +DEFAULT_ANDROID_MIN_SDK_VER = 24 # Android API 24 is the default target API version for Android builds, based on Microsoft 1CS requirements # It is possible to build from source using API level 21 and higher as the target SDK version -DEFAULT_ANDROID_TARGET_SDK_VER = 24 +DEFAULT_ANDROID_TARGET_SDK_VER = 34 def _parse_build_settings(args): @@ -79,6 +79,7 @@ def _build_aar(args): build_settings = _parse_build_settings(args) build_dir = os.path.abspath(args.build_dir) ops_config_path = os.path.abspath(args.include_ops_by_config) if args.include_ops_by_config else None + qnn_android_build = "--use_qnn" in build_settings["build_params"] # Setup temp environment for building temp_env = os.environ.copy() @@ -94,6 +95,26 @@ def _build_aar(args): base_build_command = [sys.executable, BUILD_PY] + 
build_settings["build_params"] + ["--config=" + build_config] header_files_path = "" + if qnn_android_build: + qnn_home = args.qnn_path + sdk_file = os.path.join(qnn_home, "sdk.yaml") + qnn_sdk_version = None + with open(sdk_file) as f: + for line in f: + if line.strip().startswith("version:"): + # yaml file has simple key: value format with version as key + qnn_sdk_version = line.split(":", 1)[1].strip() + break + + # Note: The QNN package version does not follow Semantic Versioning (SemVer) format. + # only use major.minor.patch version for qnn sdk version and truncate the build_id info if any + # yaml file typically has version like 2.26.0 + if qnn_sdk_version: + qnn_sdk_version = ".".join(qnn_sdk_version.split(".")[:3]) + base_build_command += ["--qnn_home=" + qnn_home] + else: + raise ValueError("Error: QNN SDK version not found in sdk.yaml file.") + # Build binary for each ABI, one by one for abi in build_settings["build_abis"]: abi_build_dir = os.path.join(intermediates_dir, abi) @@ -156,8 +177,13 @@ def _build_aar(args): if "--enable_training_apis" in build_settings["build_params"] else "-DENABLE_TRAINING_APIS=0" ), + "-DreleaseVersionSuffix=" + os.getenv("RELEASE_VERSION_SUFFIX", ""), ] + # Add qnn specific parameters + if qnn_android_build: + gradle_command.append(f"-DqnnVersion={qnn_sdk_version}") + # clean, build, and publish to a local directory subprocess.run([*gradle_command, "clean"], env=temp_env, shell=False, check=True, cwd=JAVA_ROOT) subprocess.run([*gradle_command, "build"], env=temp_env, shell=False, check=True, cwd=JAVA_ROOT) @@ -182,6 +208,8 @@ def parse_args(): "--android_ndk_path", type=str, default=os.environ.get("ANDROID_NDK_HOME", ""), help="Path to the Android NDK" ) + parser.add_argument("--qnn_path", type=str, default=os.environ.get("QNN_HOME", ""), help="Path to the QNN SDK") + parser.add_argument( "--build_dir", type=str, diff --git a/tools/ci_build/github/android/default_full_aar_build_settings.json b/tools/ci_build/github/android/default_full_aar_build_settings.json index b0eff75812673..1c7769c623d41 100644 --- a/tools/ci_build/github/android/default_full_aar_build_settings.json +++ b/tools/ci_build/github/android/default_full_aar_build_settings.json @@ -5,8 +5,8 @@ "x86", "x86_64" ], - "android_min_sdk_version": 21, - "android_target_sdk_version": 24, + "android_min_sdk_version": 24, + "android_target_sdk_version": 34, "build_params": [ "--enable_lto", "--android", diff --git a/tools/ci_build/github/android/default_qnn_aar_build_settings.json b/tools/ci_build/github/android/default_qnn_aar_build_settings.json new file mode 100644 index 0000000000000..599c108f830e7 --- /dev/null +++ b/tools/ci_build/github/android/default_qnn_aar_build_settings.json @@ -0,0 +1,19 @@ +{ + "build_abis": [ + "arm64-v8a" + ], + "android_min_sdk_version": 21, + "android_target_sdk_version": 24, + "build_params": [ + "--enable_lto", + "--android", + "--parallel", + "--cmake_generator=Ninja", + "--build_java", + "--build_shared_lib", + "--use_qnn", + "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF", + "--skip_tests" + + ] +} diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 0c28b272f7fa3..4991b4329646f 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -4,30 +4,47 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |Operator|Note| |--------|------| 
|ai.onnx:Add|| +|ai.onnx:Argmax|| |ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| +|ai.onnx:Cast|| |ai.onnx:Clip|| |ai.onnx:Concat|| |ai.onnx:Conv|Only 1D/2D Conv is supported.<br/>Bias if provided must be constant.| |ai.onnx:ConvTranspose|Weight and bias must be constant.<br/>padding_type of SAME_UPPER/SAME_LOWER is not supported.<br/>kernel_shape must have default values.<br/>output_shape is not supported.<br/>output_padding must have default values.| |ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.| |ai.onnx:Div|| +|ai.onnx:Erf|| |ai.onnx:Gemm|Input B must be constant.| +|ai.onnx:Gelu|| |ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:GridSample|4D input.<br/>'mode' of 'linear' or 'zeros'.<br/>
(mode==linear && padding_mode==reflection && align_corners==0) is not supported.| +|ai.onnx:GroupNormalization|| +|ai.onnx:InstanceNormalization|| +|ai.onnx:LayerNormalization|| |ai.onnx:LeakyRelu|| |ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.| |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| +|ai.onnx:Max|| |ai.onnx:Mul|| |ai.onnx:Pow|Only supports cases when both inputs are fp32.| +|ai.onnx:PRelu|| |ai.onnx:Reciprocal|this ask for a `epislon` (default 1e-4) where onnx don't provide| +|ai.onnx:ReduceSum|| +|ai.onnx:ReduceMean|| +|ai.onnx:ReduceMax|| |ai.onnx:Relu|| |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| +|ai.onnx:Round|| +|ai.onnx:Shape|| |ai.onnx:Slice|starts/ends/axes/steps must be constant initializers.| |ai.onnx:Split|If provided, `splits` must be constant.| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| +|ai.onnx:Softmax|| |ai.onnx:Sqrt|| +|ai.onnx:Squeeze|| |ai.onnx:Tanh|| |ai.onnx:Transpose|| +|ai.onnx:Unsqueeze|| diff --git a/tools/ci_build/github/apple/get_simulator_device_info.py b/tools/ci_build/github/apple/get_simulator_device_info.py index 7de9aa13912e0..aa693038b4394 100755 --- a/tools/ci_build/github/apple/get_simulator_device_info.py +++ b/tools/ci_build/github/apple/get_simulator_device_info.py @@ -8,6 +8,7 @@ import functools import itertools import json +import os import subprocess @@ -37,7 +38,7 @@ def __lt__(self, other: Version) -> bool: def get_simulator_device_info( requested_runtime_platform: str = "iOS", requested_device_type_product_family: str = "iPhone", - max_runtime_version_str: str | None = None, + requested_runtime_version_str: str | None = None, ) -> dict[str, str]: """ Retrieves simulator device information from Xcode. @@ -45,11 +46,13 @@ def get_simulator_device_info( :param requested_runtime_platform: The runtime platform to select. :param requested_device_type_product_family: The device type product family to select. - :param max_runtime_version_str: The maximum runtime version to allow. + :param requested_runtime_version_str: The runtime version to select. If unspecified, selects the latest one. :return: A dictionary containing information about the selected simulator device. 
""" - max_runtime_version = Version(max_runtime_version_str) if max_runtime_version_str is not None else None + requested_runtime_version = ( + Version(requested_runtime_version_str) if requested_runtime_version_str is not None else None + ) simctl_proc = subprocess.run( ["xcrun", "simctl", "list", "--json", "--no-escape-slashes"], @@ -73,7 +76,7 @@ def runtime_filter(runtime) -> bool: if runtime["platform"] != requested_runtime_platform: return False - if max_runtime_version is not None and Version(runtime["version"]) > max_runtime_version: + if requested_runtime_version is not None and Version(runtime["version"]) != requested_runtime_version: return False return True @@ -108,6 +111,9 @@ def device_filter(device) -> bool: ): runtime_id_and_device_pairs.extend((runtime_id, device) for device in filter(device_filter, device_list)) + if len(runtime_id_and_device_pairs) == 0: + raise ValueError("Failed to find requested simulator device info.") + # sort key - tuple of (runtime version, device type min runtime version) # the secondary device type min runtime version value is to treat more recent device types as greater def runtime_id_and_device_pair_key(runtime_id_and_device_pair): @@ -137,13 +143,20 @@ def runtime_id_and_device_pair_key(runtime_id_and_device_pair): def main(): + requested_runtime_version_environment_variable_name = "ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION" + parser = argparse.ArgumentParser(description="Gets simulator info from Xcode and prints it in JSON format.") - parser.add_argument("--max-runtime-version", help="The maximum runtime version to allow.") + parser.add_argument( + "--requested-runtime-version", + default=os.environ.get(requested_runtime_version_environment_variable_name, None), + help="The requested runtime version. " + f"This may also be specified with the {requested_runtime_version_environment_variable_name} " + "environment variable. The command line option takes precedence. 
" + "An unspecified value means the latest available runtime version.", + ) args = parser.parse_args() - info = get_simulator_device_info( - max_runtime_version_str=args.max_runtime_version, - ) + info = get_simulator_device_info(requested_runtime_version_str=args.requested_runtime_version) print(json.dumps(info, indent=2)) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 7bc1cd669bbff..c3dbee336b69d 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 jobs: - job: Build_QNN_EP @@ -63,9 +63,6 @@ jobs: - script: | set -e -x - rm -rf /tmp/scripts - cp -r tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts /tmp - /tmp/scripts/install_protobuf.sh -p $(Build.BinariesDirectory)/installed -d cmake/deps.txt python3 tools/ci_build/build.py \ --config Release \ --android \ @@ -78,7 +75,7 @@ jobs: --use_qnn \ --qnn_home $(QnnSDKRootDir) \ --cmake_generator=Ninja \ - --skip_tests --path_to_protoc_exe $(Build.BinariesDirectory)/installed/bin/protoc + --skip_tests displayName: Build QNN EP - script: | diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 6fd02f6b59867..aca06c320d1d3 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -40,9 +40,8 @@ parameters: default: 0 variables: - - template: templates/common-variables.yml - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 - name: linux_trt_version value: 10.3.0.26-1.cuda11.8 - name: Repository @@ -116,15 +115,15 @@ stages: set -ex; \ env; \ ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build --cmake_generator Ninja \ --config Release --update --build \ --skip_submodule_sync \ --build_shared_lib \ --parallel \ --build_wheel \ - --enable_onnx_tests --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda-${{variables.common_cuda_version}} --cudnn_home=/usr/local/cuda-${{variables.common_cuda_version}} \ - --enable_cuda_profiling --enable_cuda_nhwc_ops \ + --enable_onnx_tests --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \ + --enable_cuda_profiling \ --enable_pybind --build_java \ --use_cache \ --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \ @@ -181,6 +180,17 @@ stages: TargetPath: '$(Build.BinariesDirectory)/Release' SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv + Context: tools/ci_build/github/linux/docker/ + ScriptName: tools/ci_build/get_docker_image.py + DockerBuildArgs: " + --build-arg 
BUILD_UID=$( id -u ) + " + Repository: onnxruntimeubuntupackagestest_cuda11 + UseImageCacheContainerRegistry: false + UpdateDepsTxt: false - task: Cache@2 inputs: @@ -197,14 +207,15 @@ stages: -v $(Build.BinariesDirectory)/Release:/Release \ -v $(STABLE_DIFFUSION_MODEL_CACHE):/model_cache:rw \ -v $(GenerateImage_DIR):/images:rw \ - nvcr.io/nvidia/pytorch:22.11-py3 \ + onnxruntimeubuntupackagestest_cuda11 \ bash -c ' \ set -ex; \ python3 --version; \ python3 -m pip install --upgrade pip; \ python3 -m pip install /Release/*.whl; \ pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion; \ - python3 -m pip install -r requirements-cuda11.txt; \ + python3 -m pip install -r requirements/cuda11/requirements.txt; \ + python3 -m pip install numpy==1.22.2; \ python3 -m pip install --upgrade polygraphy onnx-graphsurgeon ; \ echo Generate an image guided by a text prompt; \ python3 demo_txt2img.py --framework-model-dir /model_cache --seed 1 --deterministic "astronaut riding a horse on mars" ; \ @@ -235,7 +246,7 @@ stages: - script: | docker run --rm --gpus all -v $PWD:/workspace \ -v $(CLIP_MODEL_CACHE):/model_cache:rw \ - nvcr.io/nvidia/pytorch:22.11-py3 \ + onnxruntimeubuntupackagestest_cuda11 \ bash -c ' set -x; \ python3 --version; \ @@ -262,7 +273,7 @@ stages: - script: | docker run --rm --gpus all -v $PWD:/workspace \ -v $(CLIP_MODEL_CACHE):/model_cache:rw \ - nvcr.io/nvidia/pytorch:22.11-py3 \ + onnxruntimeubuntupackagestest_cuda11 \ bash -c ' set -ex; \ python3 --version; \ @@ -270,6 +281,7 @@ stages: pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion/; \ image2=$(find $(pwd) -name "astronaut_riding_a_h*.png") ; \ pushd test; \ + python3 -m pip install numpy==1.22.2; \ python3 -m pip install -r requirements.txt; \ echo check demo_txt2image.py generate image; \ python3 -u check_image.py --image1 astronaut_riding_txt2image-DDIM-50.png --image2 $image2 --cache_dir /model_cache ; \ @@ -435,7 +447,7 @@ stages: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg Context: tools/ci_build/github/linux/docker/ ScriptName: tools/ci_build/get_docker_image.py DockerBuildArgs: '--build-arg BUILD_UID=$( id -u )' diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 50d4d8a912585..4e5d9a70beb66 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -34,11 +34,8 @@ stages: # build Python packages # Linux GPU only - ${{ if parameters.BuildPythonPackages }}: - - template: templates/py-packaging-stage.yml + - template: stages/py-gpu-packaging-stage.yml parameters: enable_linux_gpu: true - enable_linux_cpu: false - enable_windows_cpu: false enable_windows_gpu: false - enable_mac_cpu: false - enable_linux_arm: false + cuda_version: 12.2 diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index e2d977bd60986..798868f3b957e 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: 
QnnSdk displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.28.0.241029 resources: repositories: @@ -77,13 +77,14 @@ resources: ref: 5eda9aded5462201e6310105728d33016e637ea7 variables: +- template: templates/common-variables.yml - name: ReleaseVersionSuffix value: '' - name: win_trt_version value: 11.8 - name: win_trt_home - value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda11 }} - name: win_cuda_home value: $(Agent.TempDirectory)\v11.8 @@ -111,6 +112,7 @@ stages: BuildVariant: 'default' SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + QnnSDKVersion: ${{ parameters.QnnSdk }} - template: stages/java-cuda-packaging-stage.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml b/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml deleted file mode 100644 index 43e668eef8d00..0000000000000 --- a/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml +++ /dev/null @@ -1,44 +0,0 @@ -parameters: -- name: DryRun - type: boolean - default: false - displayName: "Do a dry-run and do not remove any images" -- name: CacheHistoryDays - type: number - default: 4 - displayName: "The length of the cache history in days" -- name: CacheMinAccessCount - type: number - default: 5 - displayName: "The minimum access count over the cache history" - -variables: - ${{ if eq(parameters.DryRun, true) }}: - DryRunArgument: "--dry-run" - -jobs: -- job: Clean_Build_Docker_Image_Cache - - pool: onnxruntime-Ubuntu2204-AMD-CPU - - timeoutInMinutes: 30 - - steps: - - checkout: self - submodules: false - fetchDepth: 1 - - - task: AzureCLI@2 - inputs: - azureSubscription: 'AIInfraBuild' - scriptType: 'bash' - scriptLocation: 'inlineScript' - inlineScript: | - tools/ci_build/clean_docker_image_cache.py \ - ${{ variables.DryRunArgument }} \ - --container-registry $(buildcache-container-registry) \ - --log-storage-account $(buildcache-log-storage-account) \ - --log-storage-account-container $(buildcache-log-storage-account-container) \ - --cache-history-days ${{ parameters.CacheHistoryDays }} \ - --cache-min-access-count ${{ parameters.CacheMinAccessCount }} - displayName: "Clean image cache" diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 7118e85e9ea4b..bc33aba57ec93 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -59,13 +59,14 @@ parameters: - 12.2 variables: + - template: templates/common-variables.yml - name: ReleaseVersionSuffix value: '' - name: win_trt_home ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda12 }} - name: win_cuda_home ${{ if eq(parameters.CudaVersion, '11.8') }}: value: $(Agent.TempDirectory)\v11.8 @@ -97,7 +98,6 @@ stages: jobs: - template: templates/c-api-linux-cpu.yml parameters: - BaseImage: 'registry.access.redhat.com/ubi8/ubi' OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' PoolName: 
'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index 78dafc0cd7dab..2eb2839cdac02 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -54,7 +54,7 @@ stages: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=registry.access.redhat.com/ubi8/ubi" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuildcentos8x64 - template: templates/linux-build-step-with-cache.yml @@ -82,22 +82,20 @@ stages: onnxruntimecpubuildcentos8x64 \ /bin/bash -c ' set -ex; \ - python3.9 /onnxruntime_src/tools/ci_build/build.py \ + python3.12 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build --cmake_generator 'Ninja' \ --config Debug \ --skip_submodule_sync \ --build_shared_lib \ --parallel --use_binskim_compliant_compile_flags \ - --build_csharp \ --enable_onnx_tests --enable_address_sanitizer \ --update --build; - LD_PRELOAD=/usr/lib64/libasan.so.6 python3.9 /onnxruntime_src/tools/ci_build/build.py \ + python3.12 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build --cmake_generator 'Ninja' \ --config Debug \ --skip_submodule_sync \ --build_shared_lib \ --parallel --use_binskim_compliant_compile_flags \ - --build_csharp \ --enable_onnx_tests --enable_address_sanitizer \ --test; ' @@ -151,7 +149,7 @@ stages: parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu Context: tools/ci_build/github/linux/docker/ - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=registry.access.redhat.com/ubi8/ubi" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuild - task: PythonScript@0 @@ -219,7 +217,7 @@ stages: /bin/bash -c " set -ex; \ ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build --cmake_generator 'Ninja' \ --config Release \ --skip_submodule_sync \ @@ -301,6 +299,7 @@ stages: machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' with_cache: true cmake_build_type: Release + python_exe_path: '/opt/python/cp310-cp310/bin/python3.10' - stage: arm64_test dependsOn: ['arm64_build'] @@ -308,4 +307,27 @@ stages: - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' + ep: 'cpu' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' + +- stage: arm64_build_xnnpack + dependsOn: [] + jobs: + - template: templates/py-linux.yml + parameters: + arch: 'aarch64' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' + with_cache: true + cmake_build_type: Release + ep: 'XNNPack' + extra_build_arg: '--use_xnnpack' + python_exe_path: '/opt/python/cp310-cp310/bin/python3.10' + +- stage: arm64_test_xnnpack + dependsOn: ['arm64_build_xnnpack'] + jobs: + - template: templates/py-packaging-linux-test-cpu.yml + parameters: + arch: 'aarch64' + ep: 'XNNPack' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml deleted file mode 100644 index b88bad2fae8bb..0000000000000 --- 
a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +++ /dev/null @@ -1,132 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- job: Linux_Build - timeoutInMinutes: 120 - workspace: - clean: all - variables: - CCACHE_DIR: $(Agent.TempDirectory)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu2204-AMD-CPU - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: recursive - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.aten_cpu - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: 'onnxruntimecpubuildaten' - UseImageCacheContainerRegistry: true - UsePipelineCache: false - - - template: templates/linux-build-step-with-cache.yml - parameters: - WithCache: true - Today: $(TODAY) - AdditionalKey: ort_aten - CacheDir: $(CCACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: CmdLine@2 - displayName: 'build' - inputs: - script: | - mkdir -p $HOME/.onnx - docker run --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume $(CCACHE_DIR):/cache \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e CCACHE_DIR=/cache \ - onnxruntimecpubuildaten \ - /bin/bash -c " - set -ex; \ - ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator Ninja \ - --config Release \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --build_wheel \ - --skip_tests \ - --cmake_extra_defines onnxruntime_ENABLE_ATEN=ON \ - --use_cache; \ - ccache -sv; \ - ccache -z" - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 'install ort_torch_ext and launch test' - inputs: - script: | - mkdir -p $HOME/.onnx - docker run --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildaten \ - bash -c "rm -rf /build/Release/onnxruntime /build/Release/pybind11 && \ - /opt/python/cp38-cp38/bin/python3 -m pip install /build/Release/dist/*.whl && \ - /opt/python/cp38-cp38/bin/python3 -m pip install /onnxruntime_src/onnxruntime/python/torch_cpp_extensions && \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/onnxruntime/test/python/contrib_ops/aten_op_tests.py && \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator Ninja \ - --config Release \ - --skip_submodule_sync \ - 
--build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --build_wheel \ - --test \ - --cmake_extra_defines onnxruntime_ENABLE_ATEN=ON" - workingDirectory: $(Build.SourcesDirectory) - - - template: templates/explicitly-defined-final-tasks.yml - - - script: | - df -h - displayName: check disk space diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml deleted file mode 100644 index 72ef660d4b344..0000000000000 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml +++ /dev/null @@ -1,99 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -# This pipeline builds the latest PyTorch commit from source -# and use it in ORT tests. See Dockerfile.manylinux2014_lort_cpu -# for the installation steps. Idally, we should only use one pipeline -# for eager mode and LazyTensor, but we split them due to recent -# PyTorch's breaking changes. -# -# TODO: once ORT eager mode can run with latest PyTorch commit, we -# should -# 1. Set --build_eager_mode when running build.py in the -# first "task" below. -# 2. Copy the second "task" above as the third task below. -- job: BuildAndTestLazyTensor - timeoutInMinutes: 120 - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - steps: - - checkout: self - clean: true - submodules: recursive - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.lort_cpu - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildlort - - - task: CmdLine@2 - displayName: 'Build ORT for Python 3.9' - inputs: - script: | - docker run --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildlort \ - python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator Ninja \ - --config Release \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --enable_lazy_tensor --enable_training --build_wheel --skip_test \ - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 'Test DORT with Python 3.9' - inputs: - script: | - docker run --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildlort \ - bash -c " - export PYTHONPATH=/build/Release && \ - python3 -m pip install /build/Release/dist/*.whl && \ - python3 /onnxruntime_src/orttraining/orttraining/test/python/orttraining_test_dort.py && \ - cd /build/Release && python3 /onnxruntime_src/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py" - workingDirectory: $(Build.SourcesDirectory) - condition: succeededOrFailed() - - - template: 
templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml index 2d3260a13f13a..4964d33067092 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml @@ -128,7 +128,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/2 --cmake_generator Ninja \ --config Debug \ --skip_submodule_sync \ @@ -210,7 +210,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/5 --cmake_generator Ninja \ --config Debug \ --skip_submodule_sync \ @@ -231,7 +231,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/6a \ --cmake_generator Ninja \ --config MinSizeRel \ @@ -258,7 +258,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/6b \ --cmake_generator Ninja \ --config MinSizeRel \ @@ -287,7 +287,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/6c \ --cmake_generator Ninja \ --config MinSizeRel \ @@ -317,7 +317,7 @@ jobs: -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ -e NIGHTLY_BUILD \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/7 \ --cmake_generator Ninja \ --config MinSizeRel \ diff --git a/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml index 7311c6e526d57..0391ecf4f5869 100644 --- a/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml @@ -49,6 +49,7 @@ jobs: Repository: onnxruntimecpubuild - task: CmdLine@2 + displayName: 'Build and test' inputs: script: | mkdir -p $HOME/.onnx @@ -61,7 +62,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build \ --config Debug Release \ --skip_submodule_sync \ diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 84e953366f9fa..7bb1deb60c6ba 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ 
b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -49,9 +49,9 @@ parameters: variables: - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241008.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 - name: Repository ${{ if eq(parameters.CudaVersion, '11.8') }}: @@ -175,13 +175,14 @@ stages: -e NVIDIA_TF32_OVERRIDE=0 \ $(Repository) \ /bin/bash -c ' + set -e nvidia-smi; \ /sbin/ldconfig -N -v $(sed "s/:/ /" <<< $LD_LIBRARY_PATH) 2>/dev/null | grep -E "libcudart.so|libcudnn.so|libnvinfer.so"; \ cat /usr/local/cuda/include/cuda.h | grep -m1 CUDA_VERSION; \ cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -m1 -A 2; \ - ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ - /tmp/python3 -m pip install /build/Release/dist/*.whl; \ - /tmp/python3 -u -c "from onnxruntime.capi._pybind_state import (OrtDevice as C_OrtDevice) ; \ + export PATH=/opt/python/cp312-cp312/bin:$PATH; \ + python3 -m pip install /build/Release/dist/*.whl; \ + python3 -u -c "from onnxruntime.capi._pybind_state import (OrtDevice as C_OrtDevice) ; \ ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0); \ print(ort_device); print(ort_device.device_type(), C_OrtDevice.cuda()); \ assert(ort_device.device_type()==1); assert(C_OrtDevice.cuda()==1);" \ @@ -204,13 +205,13 @@ stages: /bin/bash -c ' set -ex; \ cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \ - ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ - /tmp/python3 -m pip install -r /tmp/requirements.txt; \ - /tmp/python3 -m pip install /build/Release/dist/*.whl; \ + export PATH=/opt/python/cp312-cp312/bin:$PATH; \ + python3 -m pip install -r /tmp/requirements.txt; \ + python3 -m pip install /build/Release/dist/*.whl; \ cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \ cd /onnxruntime_src/java && /onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \ cd /tmp; \ - /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \ + python3 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build --config Release --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \ --enable_transformers_tool_test --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \ --enable_pybind --build_java --ctest_path "" ; \ diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 6717e8b4faaa2..9d60c9ea17cd8 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -37,16 +37,17 @@ parameters: - 12.2 variables: + - template: templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 
onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241008.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} jobs: - job: Linux_Build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index fb2c86dbf68e3..83cf26614a285 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -8,12 +8,12 @@ parameters: - name: TrtVersion displayName: TensorRT Version type: string - default: 10.4.cuda_12_5_cudnn_9 + default: 10.5.cuda_12_5_cudnn_9 values: - 8.6.cuda_11_8_cudnn_8 - 8.6.cuda_12_3_cudnn_9 - - 10.4.cuda_11_8_cudnn_8 - - 10.4.cuda_12_5_cudnn_9 + - 10.5.cuda_11_8_cudnn_8 + - 10.5.cuda_12_5_cudnn_9 - BIN - name: UseTensorrtOssParser diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml index 1cf60b47b4ded..9e2d8e49a2292 100644 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml @@ -37,9 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.1 - - name: RocmVersionPatchSuffix - value: ".3" + value: 6.2.3 jobs: - job: Linux_Build @@ -66,7 +64,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) - task: Cache@2 @@ -165,7 +163,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) - task: CmdLine@2 diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index d8c0120fc9ee5..c7b814f3dd52c 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -33,5 +33,5 @@ jobs: parameters: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' - RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2024.3.0 -x "--use_openvino CPU --build_wheel"' + RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2024.5.0 -x "--use_openvino CPU --build_wheel"' TimeoutInMinutes: 120 diff --git 
a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 02566c3c73954..009daebea165a 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml index 50f3862761320..c730cc2548038 100644 --- a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml @@ -37,9 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.1 - - name: RocmVersionPatchSuffix - value: ".3" + value: 6.1.3 jobs: - job: Linux_Build @@ -66,7 +64,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimerocm-cibuild-rocm$(RocmVersion) - task: Cache@2 @@ -166,7 +164,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimerocm-cibuild-rocm$(RocmVersion) - task: CmdLine@2 @@ -231,7 +229,11 @@ jobs: -e KERNEL_EXPLORER_TEST_USE_CUPY=1 \ -e CUPY_CACHE_DIR=/build/Release \ onnxruntimerocm-cibuild-rocm$(RocmVersion) \ - pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 4 --reruns 1 --durations=100 + /bin/bash -c " + set -ex; \ + python --version; \ + ls /opt/miniconda/envs/rocm-ci/lib/; \ + pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 4 --reruns 1 --durations=100" workingDirectory: $(Build.SourcesDirectory) displayName: 'Run kernel explorer tests' condition: succeededOrFailed() diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml index c61beb63b8b40..9576aac182bbe 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml @@ -36,9 +36,16 @@ jobs: PROTO_CACHE_DIR: $(Pipeline.Workspace)/proto_ccache ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + # Note: Keep the Xcode version and iOS simulator version compatible. 
+ # Check the table here to see what iOS simulator versions are supported by a particular Xcode version: + # https://developer.apple.com/support/xcode/ + XCODE_VERSION: 14.3.1 + IOS_SIMULATOR_RUNTIME_VERSION: 16.4 timeoutInMinutes: 150 steps: - template: templates/use-xcode-version.yml + parameters: + xcodeVersion: $(XCODE_VERSION) - template: templates/mac-build-step-with-cache.yml parameters: @@ -71,3 +78,4 @@ jobs: CCACHE_DEPEND: 1 CCACHE_SLOPPINESS: modules CCACHE_DIR: $(ORT_CACHE_DIR) + ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: $(IOS_SIMULATOR_RUNTIME_VERSION) diff --git a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml new file mode 100644 index 0000000000000..c6ab33164035c --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml @@ -0,0 +1,317 @@ +trigger: none + +variables: + DisableDockerDetector: true + +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + +extends: + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + pool: + name: onnxruntime-Win-CPU-2022 + os: windows + sdl: + git: + submodules: false + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + stages: + - stage: Windows_Build + jobs: + - template: tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml@self + parameters: + BuildArch: x64 + + - template: tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml@self + parameters: + BuildArch: x86 + + - template: tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml@self + parameters: + BuildArch: arm64 + + - template: tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml@self + parameters: + BuildArch: x64 + Runtime: static + + - template: tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml@self + parameters: + BuildArch: x86 + Runtime: static + + - template: tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml@self + parameters: + BuildArch: arm64 + Runtime: static + + - job: NuGet_Packaging + dependsOn: + - Windows_Packaging_x64_dynamic + - Windows_Packaging_x86_dynamic + - Windows_Packaging_arm64_dynamic + - Windows_Packaging_x64_static + - Windows_Packaging_x86_static + - Windows_Packaging_arm64_static + condition: succeeded() + templateContext: + inputs: + - input: pipelineArtifact + artifactName: drop_Windows_Build_Windows_Packaging_x64_dynamic + targetPath: $(Build.BinariesDirectory)/nuget-artifact-x64 + - input: pipelineArtifact + artifactName: drop_Windows_Build_Windows_Packaging_x86_dynamic + targetPath: $(Build.BinariesDirectory)/nuget-artifact-x86 + - input: pipelineArtifact + artifactName: drop_Windows_Build_Windows_Packaging_arm64_dynamic + targetPath: $(Build.BinariesDirectory)/nuget-artifact-arm64 + - input: pipelineArtifact + artifactName: drop_Windows_Build_Windows_Packaging_x64_static + targetPath: $(Build.BinariesDirectory)/nuget-artifact-x64-static-runtime + - input: pipelineArtifact + artifactName: drop_Windows_Build_Windows_Packaging_x86_static + targetPath: $(Build.BinariesDirectory)/nuget-artifact-x86-static-runtime + - input: pipelineArtifact + artifactName: drop_Windows_Build_Windows_Packaging_arm64_static + targetPath: $(Build.BinariesDirectory)/nuget-artifact-arm64-static-runtime + outputs: + - output: 
pipelineArtifact + path: '$(Build.ArtifactStagingDirectory)/merged' + artifact: drop_Windows_Build_NuGet_Packaging + + steps: + - task: PowerShell@2 + displayName: 'Bundle NuGet and other binaries' + inputs: + targetType: 'inline' + script: | + Add-Type -AssemblyName "System.IO.Compression.FileSystem" + + $nupkgs = (Get-ChildItem -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) + $x64_nuget_package_name = $nupkgs[0].Name + $x64_nuget_package = $nupkgs[0].FullName + $x64_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $x64_nupkg_unzipped_directory = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($x64_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($x64_nuget_package, $x64_nupkg_unzipped_directory) + + $nupkgs = (Get-ChildItem ..\nuget-artifact-x64-static-runtime -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) + $x64_static_runtime_nuget_package = $nupkgs[0].FullName + $x64_static_runtime_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $x64_static_runtime_nupkg_unzipped_directory = [System.IO.Path]::Combine($x64_static_runtime_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($x64_static_runtime_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($x64_static_runtime_nuget_package, $x64_static_runtime_nupkg_unzipped_directory) + + $nupkgs = (Get-ChildItem ..\nuget-artifact-x86 -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) + $x86_nuget_package = $nupkgs[0].FullName + $x86_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $x86_nupkg_unzipped_directory = [System.IO.Path]::Combine($x86_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($x86_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($x86_nuget_package, $x86_nupkg_unzipped_directory) + + $nupkgs = (Get-ChildItem ..\nuget-artifact-x86-static-runtime -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) + $x86_static_runtime_nuget_package = $nupkgs[0].FullName + $x86_static_runtime_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $x86_static_runtime_nupkg_unzipped_directory = [System.IO.Path]::Combine($x86_static_runtime_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($x86_static_runtime_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($x86_static_runtime_nuget_package, $x86_static_runtime_nupkg_unzipped_directory) + + $nupkgs = (Get-ChildItem ..\nuget-artifact-arm64 -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) + $arm64_nuget_package = $nupkgs[0].FullName + $arm64_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $arm64_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($arm64_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($arm64_nuget_package, $arm64_nupkg_unzipped_directory) + + $nupkgs = (Get-ChildItem ..\nuget-artifact-arm64-static-runtime -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) + $arm64_static_runtime_nuget_package = $nupkgs[0].FullName + $arm64_static_runtime_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $arm64_static_runtime_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm64_static_runtime_nupkg_unzipped_directory_root, 'binaries', 
[System.IO.Path]::GetFileNameWithoutExtension($arm64_static_runtime_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($arm64_static_runtime_nuget_package, $arm64_static_runtime_nupkg_unzipped_directory) + + + + $x64_static_runtime_path_old = [System.IO.Path]::Combine($x64_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-x64', '_native') + $x64_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x64', '_native', 'static') + $x86_runtime_path_old = [System.IO.Path]::Combine($x86_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') + $x86_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') + $x86_static_runtime_path_old = [System.IO.Path]::Combine($x86_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') + $x86_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native', 'static') + $arm64_runtime_path_old = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') + $arm64_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') + $arm64_static_runtime_path_old = [System.IO.Path]::Combine($arm64_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') + $arm64_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native', 'static') + + $uap_build_path_old = [System.IO.Path]::Combine($x64_static_runtime_nupkg_unzipped_directory, 'build', 'native') + $uap_build_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'build', 'uap10.0') + + New-Item -Path $x64_static_runtime_path_new -ItemType Directory + New-Item -Path $x86_runtime_path_new -ItemType Directory + New-Item -Path $x86_static_runtime_path_new -ItemType Directory + New-Item -Path $arm64_runtime_path_new -ItemType Directory + New-Item -Path $arm64_static_runtime_path_new -ItemType Directory + + Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.dll')) $x86_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.lib')) $x86_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'microsoft.ai.machinelearning.dll')) $x86_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'microsoft.ai.machinelearning.lib')) $x86_runtime_path_new + + Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'onnxruntime.dll')) $arm64_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'onnxruntime.lib')) $arm64_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.dll')) $arm64_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.lib')) $arm64_runtime_path_new + + Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'onnxruntime.dll')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'onnxruntime.dll')) + Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'onnxruntime.lib')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'onnxruntime.lib')) + Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) + Copy-Item 
([System.IO.Path]::Combine($x64_static_runtime_path_old, 'microsoft.ai.machinelearning.lib')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'microsoft.ai.machinelearning.lib')) + + Copy-Item ([System.IO.Path]::Combine($x86_static_runtime_path_old, 'onnxruntime.dll')) ([System.IO.Path]::Combine($x86_static_runtime_path_new, 'onnxruntime.dll')) + Copy-Item ([System.IO.Path]::Combine($x86_static_runtime_path_old, 'onnxruntime.lib')) ([System.IO.Path]::Combine($x86_static_runtime_path_new, 'onnxruntime.lib')) + Copy-Item ([System.IO.Path]::Combine($x86_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($x86_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) + Copy-Item ([System.IO.Path]::Combine($x86_static_runtime_path_old, 'microsoft.ai.machinelearning.lib')) ([System.IO.Path]::Combine($x86_static_runtime_path_new, 'microsoft.ai.machinelearning.lib')) + + Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'onnxruntime.dll')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'onnxruntime.dll')) + Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'onnxruntime.lib')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'onnxruntime.lib')) + Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) + Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'microsoft.ai.machinelearning.lib')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'microsoft.ai.machinelearning.lib')) + + Copy-Item -Recurse $uap_build_path_old $uap_build_path_new + + $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'merged') + if (!(Test-Path $merged_nuget_path)) { + New-Item -Path $merged_nuget_path -ItemType Directory + } + + $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $x64_nuget_package_name) + Start-Process -FilePath "7z" -ArgumentList "-tzip a -r $merged_nuget ." 
-WorkingDirectory $x64_nupkg_unzipped_directory -NoNewWindow -Wait + + workingDirectory: $(Build.BinariesDirectory)\nuget-artifact-x64 + + - task: PowerShell@2 + displayName: 'Bundle Symbols NuGet' + inputs: + targetType: 'inline' + script: | + Add-Type -AssemblyName "System.IO.Compression.FileSystem" + + $nupkgs = (Get-ChildItem -Filter Microsoft.AI.MachineLearning*.snupkg -Recurse) + $x64_nuget_package_name = $nupkgs[0].Name + $x64_nuget_package = $nupkgs[0].FullName + $x64_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $x64_nupkg_unzipped_directory = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory_root, 'symbols', [System.IO.Path]::GetFileNameWithoutExtension($x64_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($x64_nuget_package, $x64_nupkg_unzipped_directory) + + $nupkgs = (Get-ChildItem ..\nuget-artifact-x86 -Filter Microsoft.AI.MachineLearning*.snupkg -Recurse) + $x86_nuget_package = $nupkgs[0].FullName + $x86_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $x86_nupkg_unzipped_directory = [System.IO.Path]::Combine($x86_nupkg_unzipped_directory_root, 'symbols', [System.IO.Path]::GetFileNameWithoutExtension($x86_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($x86_nuget_package, $x86_nupkg_unzipped_directory) + + $nupkgs = (Get-ChildItem ..\nuget-artifact-arm64 -Filter Microsoft.AI.MachineLearning*.snupkg -Recurse) + $arm64_nuget_package = $nupkgs[0].FullName + $arm64_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName + $arm64_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory_root, 'symbols', [System.IO.Path]::GetFileNameWithoutExtension($arm64_nuget_package)) + [System.IO.Compression.ZipFile]::ExtractToDirectory($arm64_nuget_package, $arm64_nupkg_unzipped_directory) + + $x86_runtime_path_old = [System.IO.Path]::Combine($x86_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') + $x86_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') + $arm64_runtime_path_old = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') + $arm64_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') + + New-Item -Path $x86_runtime_path_new -ItemType Directory + New-Item -Path $arm64_runtime_path_new -ItemType Directory + + Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.pdb')) $x86_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $x86_runtime_path_new + + Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'onnxruntime.pdb')) $arm64_runtime_path_new + Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $arm64_runtime_path_new + + $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'merged') + if (!(Test-Path $merged_nuget_path)) { + New-Item -Path $merged_nuget_path -ItemType Directory + } + + $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $x64_nuget_package_name) + + Start-Process -FilePath "7z" -ArgumentList "-tzip a -r $merged_nuget ." 
-WorkingDirectory $x64_nupkg_unzipped_directory -NoNewWindow -Wait + + $merged_nuget_without_pdb = [System.IO.Path]::ChangeExtension($merged_nuget, '.nupkg') + + # Now we combine the DLLs and PDBs together, put them back in a folder under $(Build.SourcesDirectory) + # We won't upload the unzipped folder. We will just feed it to BinSkim. + 7z x -o$(Build.SourcesDirectory)\unzipped $merged_nuget + 7z -y x -o$(Build.SourcesDirectory)\unzipped $merged_nuget_without_pdb + + workingDirectory: $(Build.BinariesDirectory)\nuget-artifact-x64 + + - script: | + dir $(Build.SourcesDirectory)\unzipped\runtimes\win-x64\_native + + - task: SFP.build-tasks.custom-build-task-1.EsrpCodeSigning@5 + displayName: "Sign Nuget package" + inputs: + ConnectedServiceName: 'OnnxrunTimeCodeSign_20240611' + AppRegistrationClientId: '53d54d02-978d-4305-8572-583cf6711c4f' + AppRegistrationTenantId: '72f988bf-86f1-41af-91ab-2d7cd011db47' + AuthAKVName: 'buildkeyvault' + AuthCertName: '53d54d02-SSL-AutoRotate' + AuthSignCertName: '53d54d02-978d-4305-8572-583cf6711c4f' + + FolderPath: $(Build.ArtifactStagingDirectory) + Pattern: '*.nupkg' + SessionTimeout: 90 + ServiceEndpointUrl: 'https://api.esrp.microsoft.com/api/v2' + MaxConcurrency: 25 + signConfigType: inlineSignParams + inlineOperation: | + [ + { + "keyCode": "CP-401405", + "operationSetCode": "NuGetSign", + "parameters": [ ], + "toolName": "sign", + "toolVersion": "6.2.9304.0" + }, + { + "keyCode": "CP-401405", + "operationSetCode": "NuGetVerify", + "parameters": [ ], + "toolName": "sign", + "toolVersion": "6.2.9304.0" + } + ] + + - job: NuGet_Publishing + dependsOn: + - NuGet_Packaging + condition: succeeded() + templateContext: + inputs: + - input: pipelineArtifact + artifactName: drop_Windows_Build_NuGet_Packaging + targetPath: $(Build.BinariesDirectory)/merged + outputs: + - output: nuget + # condition: and(succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) # Optional condition + useDotNetTask: false # The default is false to use the NuGetCommand task. Set to true to use the DotNetCoreCLI task to publish packages. + packagesToPush: '$(Build.BinariesDirectory)/packages/*.nupkg;!$(Build.BinariesDirectory)/packages/*.symbols.nupkg' + packageParentPath: '$(Build.BinariesDirectory)/' + publishVstsFeed: PublicPackages/ORT-Nightly # Required when pushing to internal feed. + nuGetFeedType: internal # Change to external when publishing to external feed + allowPackageConflicts: true # Optional. NuGetCommand task only. 
+ publishPackageMetadata: true # Optional + steps: + - powershell: | + Rename-Item -Path merged packages + + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Rename nuget folder' diff --git a/tools/ci_build/github/azure-pipelines/nuget/nuget_config/nuget.config b/tools/ci_build/github/azure-pipelines/nuget/nuget_config/nuget.config new file mode 100644 index 0000000000000..f654900ad04d1 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/nuget/nuget_config/nuget.config @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/tools/ci_build/github/azure-pipelines/nuget/nuget_config/x64/packages.config b/tools/ci_build/github/azure-pipelines/nuget/nuget_config/x64/packages.config new file mode 100644 index 0000000000000..294bd926a34cb --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/nuget/nuget_config/x64/packages.config @@ -0,0 +1,6 @@ + + + + + + diff --git a/tools/ci_build/github/azure-pipelines/nuget/nuget_config/x86/packages.config b/tools/ci_build/github/azure-pipelines/nuget/nuget_config/x86/packages.config new file mode 100644 index 0000000000000..3528545dfb06e --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/nuget/nuget_config/x86/packages.config @@ -0,0 +1,6 @@ + + + + + + diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index c73cbb102a640..29b16c47bca5d 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -75,7 +75,7 @@ stages: - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - task: onebranch.pipeline.tsaoptions@1 displayName: 'OneBranch TSAOptions' @@ -85,7 +85,7 @@ stages: - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true architecture: ${{ parameters.BuildArch }} diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index b1e5816fb748e..f9ecfb7cf7938 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -23,6 +23,7 @@ stages: pool: ${{ parameters.AgentPool }} variables: + - template: ../../templates/common-variables.yml - name: OnnxRuntimeBuildDirectory value: '$(Build.BinariesDirectory)' @@ -52,7 +53,7 @@ stages: inputs: script: | ln -sf /data/models $(Build.BinariesDirectory) - + # As for Debian installation, replace '-1.' 
by '-1+' when assigning trt version below - ${{if contains(parameters.StageSuffix , 'GPU') }}: - template: ../../templates/get-docker-image-steps.yml parameters: @@ -61,7 +62,7 @@ stages: ${{ if eq(parameters.CudaVersion, '12.2') }}: DockerBuildArgs: " --build-arg BASEIMAGE=nvidia/cuda:12.2.2-devel-ubuntu20.04 - --build-arg TRT_VERSION=10.4.0.26-1+cuda12.6 + --build-arg TRT_VERSION=${{ replace(variables.linux_trt_version_cuda12, '-1.', '-1+') }} --build-arg BUILD_UID=$( id -u ) " ${{ else }}: diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml index ddcea447adc94..4842fcbd4dcfb 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml @@ -40,7 +40,7 @@ stages: steps: - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true architecture: x64 diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml deleted file mode 100644 index 04462a60776d7..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ /dev/null @@ -1,114 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- job: Linux_Build - timeoutInMinutes: 180 - workspace: - clean: all - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu-2204-Training-CPU - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: none - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.8' - addToPath: true - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildpythonx86_64 - - - task: Cache@2 - inputs: - key: '"$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - cacheHitVar: CACHE_RESTORED - restoreKeys: | - "$(TODAY)" | "$(Build.SourceBranch)" - "$(TODAY)" | - displayName: Cach Task - - - task: CmdLine@2 - displayName: 'build' - inputs: - script: | - set -e -x - mkdir -p $HOME/.onnx - mkdir -p $(Pipeline.Workspace)/ccache - docker run --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume /data/models:/build/models:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume $(Pipeline.Workspace)/ccache:/cache \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e CCACHE_DIR=/cache \ - 
onnxruntimecpubuildpythonx86_64 \ - /bin/bash -c " - set -ex; \ - ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator 'Unix Makefiles' \ - --config Release \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --build_wheel \ - --enable_onnx_tests \ - --enable_training \ - --use_cache; \ - ccache -sv; \ - ccache -z" - workingDirectory: $(Build.SourcesDirectory) - - - task: PublishTestResults@2 - displayName: 'Publish unit test results' - inputs: - testResultsFiles: '**/*.results.xml' - searchFolder: '$(Build.BinariesDirectory)' - testRunTitle: 'Unit Test Run' - condition: succeededOrFailed() diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml deleted file mode 100644 index 494035637a79d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ /dev/null @@ -1,55 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- template: templates/linux-ci.yml - parameters: - AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' - JobName: 'Onnxruntime_Linux_GPU_Training' - RunDockerBuildArgs: > - -o ubuntu20.04 -d gpu - -t onnxruntime_orttraining_ortmodule_tests_image - -u - -e - -x " - --enable_training - --config Release - --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 - --build_wheel - --enable_nvtx_profile - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 - " - RunInjectedPipeline: 'true' - InjectedPipeline: 'orttraining-linux-gpu-test-ci-pipeline.yml' - DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image' - TimeoutInMinutes: 190 - # Enable unreleased onnx opsets in CI builds - # This facilitates testing the implementation for the new opsets - AllowReleasedOpsetOnly: '0' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml deleted file mode 100644 index dcbee286136f0..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml +++ /dev/null @@ -1,135 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -stages: -- stage: ORTModuleDistributedTest - dependsOn: [] - jobs: - - job: Onnxruntime_Linux_GPU_ORTModule_Distributed_Test - - timeoutInMinutes: 120 - pool: 'Onnxruntime-Linux-GPU-NC24sv3' - - steps: - - 
checkout: self - clean: true - submodules: recursive - - - template: templates/jobs/download_training_test_data.yml - - - template: templates/run-docker-build-steps.yml - parameters: - RunDockerBuildArgs: | - -o ubuntu20.04 -d gpu \ - -t onnxruntime_ortmodule_distributed_tests_image \ - -x " \ - --config RelWithDebInfo \ - --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \ - --enable_training \ - --update --build \ - --build_wheel \ - " \ - -m \ - -u \ - -e - DisplayName: 'Build' - - # Entry point for all ORTModule distributed tests - # Refer to orttraining/orttraining/test/python/how_to_add_ortmodule_distributed_ci_pipeline_tests.md for guidelines on how to add new tests to this pipeline. - - script: | - docker run \ - --gpus all \ - --shm-size=1024m \ - --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(Agent.TempDirectory)/mnist:/mnist \ - onnxruntime_ortmodule_distributed_tests_image \ - bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && python3 -m pip install torch==2.3.1+cu118 --index-url https://download.pytorch.org/whl/cu118 && python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install && echo temporarily skip /build/RelWithDebInfo/launch_test.py --cmd_line_with_args 'python orttraining_ortmodule_distributed_tests.py --mnist /mnist' --cwd /build/RelWithDebInfo" \ - displayName: 'Run orttraining_ortmodule_distributed_tests.py' - condition: succeededOrFailed() - timeoutInMinutes: 30 - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - - template: templates/clean-agent-build-directory-step.yml - -- stage: DistributedInferenceTest - dependsOn: [] - jobs: - - job: Onnxruntime_Linux_GPU_Inference_Distributed_Test - - timeoutInMinutes: 120 - pool: 'Onnxruntime-Linux-GPU-NC24sv3' - - steps: - - checkout: self - clean: true - submodules: recursive - - - template: templates/run-docker-build-steps.yml - parameters: - RunDockerBuildArgs: | - -o ubuntu20.04 -d gpu \ - -t onnxruntime_ortmodule_distributed_tests_image \ - -x " \ - --config RelWithDebInfo \ - --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \ - --update --build \ - --build_wheel \ - --use_mpi \ - --enable_nccl \ - " \ - -m \ - -u \ - -e - DisplayName: 'Build' - - - script: | - docker run \ - --gpus all \ - --shm-size=1024m \ - --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /mnist:/mnist \ - onnxruntime_ortmodule_distributed_tests_image \ - bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install mpi4py onnxscript && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && mpirun -n 4 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_collective.py && mpirun -n 2 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_distributed.py && mpirun -n 2 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py" \ - displayName: 'Run onnxruntime_test_collective.py, onnxruntime_test_distributed.py and test_sharded_moe.py' - condition: succeededOrFailed() - timeoutInMinutes: 30 - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - - template: 
templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml deleted file mode 100644 index e13ef9160bed3..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-nightly-ortmodule-test-pipeline.yml +++ /dev/null @@ -1,33 +0,0 @@ -trigger: none - -jobs: -- job: Onnxruntime_Linux_Nightly_ORTModule_tests - - timeoutInMinutes: 120 - pool: 'Onnxruntime-Linux-GPU-NC6sv3' - - steps: - - checkout: self - clean: true - submodules: recursive - - # Entry point for all ortmodule training tests - - script: | - COMMIT_ID=$(python3 -c "import onnxruntime; print(onnxruntime.get_build_info().split('git-commit-id=')[1].split(',')[0])") - cd $(Build.SourcesDirectory) - git checkout $COMMIT_ID - git branch - echo "Retrieved ONNX Runtime Commit ID: $COMMIT_ID" - docker run \ - --gpus all \ - --rm \ - --volume $(Build.SourcesDirectory)/orttraining/orttraining/test/python:/onnxruntime_src \ - --volume $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly:/requirements_torch_nightly \ - ptebic.azurecr.io/internal/aifx/acpt/nightly-ubuntu-cuda-torch-dev \ - bash -c "python3 -m pip install -r /requirements_torch_nightly/requirements.txt && ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 python3 -m pytest -sv /onnxruntime_src/orttraining_test_ortmodule_api.py && ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 python3 -m pytest -sv /onnxruntime_src/orttraining_test_ortmodule_api.py" - displayName: 'Run ORTModule Tests' - condition: succeededOrFailed() - timeoutInMinutes: 120 - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-mac-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-mac-ci-pipeline.yml deleted file mode 100644 index ec5c30787b611..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-mac-ci-pipeline.yml +++ /dev/null @@ -1,37 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -stages: -- template: templates/mac-cpu-packaging-pipeline.yml - parameters: - AllowReleasedOpsetOnly: 0 - BuildForAllArchs: false - AdditionalBuildFlags: --enable_training --build_objc - WithCache: true diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml deleted file mode 100644 index 2c520a25cb39e..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ /dev/null @@ -1,405 +0,0 @@ -trigger: none - -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - 'js/web' - - 'onnxruntime/core/providers/js' -name: 'orttraining_ci_$(Date:yyyyMMdd)_$(Rev:r)' - -resources: - repositories: - - repository: manylinux - type: Github - endpoint: Microsoft - name: pypa/manylinux - ref: 
5eda9aded5462201e6310105728d33016e637ea7 - -variables: - - name: video - value: 44 - - name: render - value: 109 - - name: RocmVersion - value: 6.1 - - name: RocmVersionPatchSuffix - value: ".3" - - name: BuildConfig - value: Release - -jobs: -- job: Linux_Build_manylinux - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - timeoutInMinutes: 240 - - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: recursive - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur - --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg BUILD_UID=$(id -u) - --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) - --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root - --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: - --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib - Repository: onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build - CheckOutManyLinux: true - - - task: Cache@2 - inputs: - key: '"manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - cacheHitVar: CACHE_RESTORED - restoreKeys: | - "manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" - "manylinux" | "$(TODAY)" | - displayName: Cache Task - - - script: mkdir -p $(CCACHE_DIR) - condition: ne(variables.CACHE_RESTORED, 'true') - displayName: Create Cache Dir - - - task: CmdLine@2 - inputs: - script: |- - export ROCM_HOME=/opt/rocm - docker run --rm \ - --ipc=host \ - --network=host \ - --cap-add=SYS_PTRACE \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --user $UID:$(id -g $USER) \ - -e CC=/opt/rh/gcc-toolset-12/root/usr/bin/cc -e CXX=/opt/rh/gcc-toolset-12/root/usr/bin/c++ -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" \ - -e CCACHE_DIR=/cache \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(CCACHE_DIR):/cache \ - --workdir /onnxruntime_src \ - onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build \ - /bin/bash -c " - set -ex; \ - ccache -s; \ - /opt/python/cp39-cp39/bin/python3 tools/ci_build/build.py \ - --config $(BuildConfig) \ - --enable_training \ - --mpi_home /opt/ompi \ - --cmake_extra_defines \ - CMAKE_HIP_COMPILER=${ROCM_HOME}/llvm/bin/clang++ \ - onnxruntime_BUILD_UNIT_TESTS=OFF \ - FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER \ - --use_cache \ - --use_rocm \ - --use_migraphx \ - --rocm_version=$(RocmVersion) \ - --rocm_home ${ROCM_HOME} \ - --nccl_home ${ROCM_HOME}\ - --update \ - --build_dir /build \ - --build \ - --parallel \ - --build_wheel \ - --skip_submodule_sync \ - --skip_tests; 
\ - ccache -sv; \ - ccache -z" - displayName: 'Build onnxruntime' - - - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Build_ubuntu - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - timeoutInMinutes: 240 - - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: recursive - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" - Repository: onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-build - - #- script: |- - # sed -i 's|session_options.use_deterministic_compute = False|session_options.use_deterministic_compute = True|g' \ - # orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py - # displayName: 'Toggle ON deterministic compute mode for ORTModule' - - - task: Cache@2 - inputs: - key: '"$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - cacheHitVar: CACHE_RESTORED - restoreKeys: | - "$(TODAY)" | "$(Build.SourceBranch)" - "$(TODAY)" | - displayName: Cache Task - - - script: mkdir -p $(CCACHE_DIR) - condition: ne(variables.CACHE_RESTORED, 'true') - displayName: Create Cache Dir - - - task: CmdLine@2 - inputs: - script: |- - export ROCM_HOME=/opt/rocm - docker run --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --user $UID:$(id -g $USER) \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(CCACHE_DIR):/cache \ - -e CCACHE_DIR=/cache \ - --workdir /onnxruntime_src \ - onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-build \ - /bin/bash -c " - set -ex; \ - ccache -s; \ - python tools/ci_build/build.py \ - --config $(BuildConfig) \ - --enable_training \ - --mpi_home /opt/ompi \ - --cmake_extra_defines \ - CMAKE_HIP_COMPILER=${ROCM_HOME}/llvm/bin/clang++ \ - onnxruntime_BUILD_KERNEL_EXPLORER=ON \ - --use_cache \ - --use_rocm \ - --rocm_version=$(RocmVersion) \ - --rocm_home ${ROCM_HOME} \ - --nccl_home ${ROCM_HOME}\ - --update \ - --build_dir /build \ - --build \ - --parallel \ - --build_wheel \ - --skip_submodule_sync \ - --skip_tests; \ - ccache -sv; \ - ccache -z" - displayName: 'Build onnxruntime' - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test_ubuntu - workspace: - clean: all - pool: AMD-GPU - dependsOn: - - Linux_Build_ubuntu - timeoutInMinutes: 240 - - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - buildType: 'current' - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - checkout: self - clean: true - submodules: recursive - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg 
ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" - Repository: onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test - - - task: Bash@3 - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_clean_device.sh - arguments: -n $(Agent.Name) -d $HIP_VISIBLE_DEVICES -r $DRIVER_RENDER - displayName: 'Check ROCm Environment' - - # TODO: move to use ci_build/build.py driven tests - - task: CmdLine@2 - inputs: - script: |- - docker run --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user onnxruntimedev \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - -e OPENBLAS_NUM_THREADS=1 \ - -e OPENMP_NUM_THREADS=1 \ - -e MKL_NUM_THREADS=1 \ - -e PYTHONPATH=/build/$(BuildConfig) \ - onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ - /bin/bash -c " - set -ex; \ - pip install -r /onnxruntime_src/tools/ci_build/requirements/transformers-test/requirements.txt; \ - pytest /onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py -v -n 4 --reruns 1" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run tranformers tests' - condition: succeededOrFailed() - - - task: CmdLine@2 - inputs: - script: |- - docker run --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user onnxruntimedev \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --workdir /build/$(BuildConfig) \ - onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ - /bin/bash -c " - set -ex; \ - chmod a+x /build/Release/onnxruntime_test_all; \ - /onnxruntime_src/tools/ci_build/github/pai/pai_test_launcher.sh" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run onnxruntime unit tests' - condition: succeeded() - - - task: CmdLine@2 - inputs: - script: |- - docker run --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user onnxruntimedev \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - -e OPENBLAS_NUM_THREADS=1 \ - -e OPENMP_NUM_THREADS=1 \ - -e MKL_NUM_THREADS=1 \ - -e KERNEL_EXPLORER_BUILD_DIR=/build/$(BuildConfig) \ - -e KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE=8 \ - -e KERNEL_EXPLORER_TEST_USE_CUPY=1 \ - onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ - pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 4 --reruns 1 --durations=100 - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run kernel explorer tests' - condition: succeededOrFailed() - - - task: CmdLine@2 - inputs: - script: |- - docker run --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user onnxruntimedev \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --workdir /build/$(BuildConfig) \ - onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ - /bin/bash -c " - set -ex; \ - export PYTHONPATH=/build/$(BuildConfig); \ - python -m 
onnxruntime.training.ortmodule.torch_cpp_extensions.install; \ - bash /onnxruntime_src/tools/ci_build/github/pai/pai_huggingface_bert_large_test.sh -v $(RocmVersion)" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run Python Hugging-Face BERT-L test' - condition: succeededOrFailed() - - - # Entry point for all ORTModule tests - # The onnxruntime folder is deleted in the build directory - # to enforce use of the onnxruntime wheel - - task: CmdLine@2 - inputs: - script: |- - rm -rf $(Build.BinariesDirectory)/$(BuildConfig)/onnxruntime/ - files=($(Build.BinariesDirectory)/$(BuildConfig)/dist/*.whl) - echo ${files[0]} - whlfilename=$(basename ${files[0]}) - echo $whlfilename - docker run --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --device=/dev/kfd \ - --device=/dev/dri/renderD$DRIVER_RENDER \ - --group-add $(video) \ - --group-add $(render) \ - --user onnxruntimedev \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --workdir /build/$(BuildConfig) \ - onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ - /bin/bash -c " - set -ex; \ - unset PYTHONPATH; \ - pip install /build/$(BuildConfig)/dist/$whlfilename; \ - python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install; \ - mkdir /home/onnxruntimedev/mnist /home/onnxruntimedev/bert_data; \ - export ORTMODULE_DISABLE_CPU_TRAINING_TEST=1; \ - export ORTMODULE_ROCM_TEST=1; \ - python orttraining_ortmodule_tests.py \ - --mnist /home/onnxruntimedev/mnist \ - --bert_data /home/onnxruntimedev/bert_data/hf_data/glue_data/CoLA/original/raw" - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run orttraining_ortmodule_tests.py' - condition: succeededOrFailed() - - - - task: Bash@3 - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_clean_device.sh - arguments: -n $(Agent.Name) -d $HIP_VISIBLE_DEVICES -r $DRIVER_RENDER - displayName: 'Clean ROCm Environment' - condition: always() - - - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml deleted file mode 100644 index a71e10f95f3e1..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +++ /dev/null @@ -1,28 +0,0 @@ -parameters: -- name: cmake_build_type - type: string - displayName: 'Linux/Windows/iOS packages cmake build type.' 
- default: 'Release' - values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel - -trigger: none - -stages: -- template: templates/py-packaging-stage.yml - parameters: - build_py_parameters: --enable_training - cmake_build_type: ${{ parameters.cmake_build_type }} - enable_linux_gpu: false - enable_linux_cpu: true - enable_windows_cpu: true - enable_windows_gpu: false - enable_mac_cpu: true - enable_linux_arm: false - enable_windows_arm64_qnn: false - enable_windows_arm64ec_qnn: false - enable_windows_x64_qnn: false - enable_linux_x64_qnn: false diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml deleted file mode 100644 index be3f67ba450b4..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml +++ /dev/null @@ -1,28 +0,0 @@ -trigger: none - -parameters: - - name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - - - name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - -stages: -- template: templates/py-packaging-training-cuda-stage.yml - parameters: - build_py_parameters: --enable_training --update --build - torch_version: '2.0.0' - opset_version: '17' - cuda_version: '11.8' - cmake_cuda_architectures: 60;61;70;75;80;86 - docker_file: Dockerfile.manylinux2_28_training_cuda11_8 - agent_pool: Onnxruntime-Linux-GPU - upload_wheel: 'yes' - debug_build: false - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - build_pool_name: 'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml deleted file mode 100644 index 74d299c728911..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml +++ /dev/null @@ -1,16 +0,0 @@ -trigger: none - -stages: -- template: templates/py-packaging-training-cuda-stage.yml - parameters: - # set the paralle count to reduce memory/build_threads to avoid OOM - build_py_parameters: --enable_training --update --build --parallel 8 - torch_version: '2.1.0' - opset_version: '17' - cuda_version: '12.2' - cmake_cuda_architectures: 70;75;80;86;90 - docker_file: Dockerfile.manylinux2_28_training_cuda12_2 - agent_pool: Onnxruntime-Linux-GPU - upload_wheel: 'yes' - debug_build: false - build_pool_name: 'onnxruntime-Ubuntu-2204-Training-CPU' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml deleted file mode 100644 index a53d110a20a7a..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml +++ /dev/null @@ -1,65 +0,0 @@ -trigger: none - -resources: - repositories: - - repository: manylinux - type: Github - endpoint: Microsoft - name: pypa/manylinux - ref: 5eda9aded5462201e6310105728d33016e637ea7 - -stages: -- stage: "Python_Packaging_ROCm60_Release" - jobs: - - template: templates/rocm.yml - parameters: - PythonVersion: '3.8' - RocmVersion: '6.0' - RocmVersionPatchSuffix: ".3" - - template: templates/rocm.yml - parameters: - PythonVersion: '3.9' - RocmVersion: '6.0' - RocmVersionPatchSuffix: ".3" - - template: templates/rocm.yml - parameters: - PythonVersion: '3.10' - RocmVersion: '6.0' - RocmVersionPatchSuffix: ".3" - -- 
stage: "Python_Packaging_ROCm60_Debug" - jobs: - - template: templates/rocm.yml - parameters: - PythonVersion: '3.8' - RocmVersion: '6.0' - RocmVersionPatchSuffix: ".3" - BuildConfig: 'Debug' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.9' - RocmVersion: '6.0' - RocmVersionPatchSuffix: ".3" - BuildConfig: 'Debug' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.10' - RocmVersion: '6.0' - RocmVersionPatchSuffix: ".3" - BuildConfig: 'Debug' - -- stage: "Python_Packaging_ROCm57_Release" - condition: ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true') - jobs: - - template: templates/rocm.yml - parameters: - PythonVersion: '3.8' - RocmVersion: '5.7' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.9' - RocmVersion: '5.7' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.10' - RocmVersion: '5.7' diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 833e97b437c33..3eafd7350b25b 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -1,3 +1,20 @@ +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +variables: + - template: templates/common-variables.yml + - name: win_trt_folder + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: ${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: ${{ variables.win_trt_folder_cuda12 }} + stages: - ${{ if or(startsWith(variables['System.CollectionUri'], 'https://dev.azure.com/aiinfra/'),startsWith(variables['System.CollectionUri'], 'https://aiinfra.visualstudio.com/')) }}: - template: templates/web-ci.yml @@ -206,7 +223,7 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env_cuda.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo @@ -226,7 +243,7 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env_trt.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder }}" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo @@ -377,7 +394,7 @@ stages: - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true architecture: x64 @@ -411,7 +428,7 @@ stages: steps: - task: UsePythonVersion@0 inputs: - versionSpec: "3.9" + versionSpec: "3.12" addToPath: true architecture: "x64" @@ -447,7 +464,7 @@ stages: steps: - task: 
UsePythonVersion@0
       inputs:
-        versionSpec: "3.9"
+        versionSpec: "3.12"
         addToPath: true
         architecture: "x64"
diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml
new file mode 100644
index 0000000000000..9296928ad97e0
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml
@@ -0,0 +1,24 @@
+resources:
+  pipelines:
+  - pipeline: build
+    source: 'Python CUDA ALT Packaging Pipeline'
+    trigger: true
+    branch: main # branch to pick the artifact from; used only for manually triggered runs that test this pipeline itself
+
+stages:
+  # **** The following stage depends on all previous stages. ****
+  # GPU resources are very limited,
+  # so to use them efficiently the GPU job runs only after all CPU jobs succeed.
+  - stage: Linux_Test_CUDA_Alt_x86_64_stage
+    dependsOn:
+    jobs:
+      - template: templates/py-packaging-linux-test-cuda.yml
+        parameters:
+          arch: 'x86_64'
+          machine_pool: 'Onnxruntime-Linux-GPU'
+          python_wheel_suffix: '_gpu'
+          timeout: 480
+          docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3
+          trt_version: '10.6.0.26-1.cuda11.8'
+          cuda_version: '11.8'
+
diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml
new file mode 100644
index 0000000000000..93a38b212d934
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml
@@ -0,0 +1,50 @@
+trigger: none
+resources:
+  repositories:
+    - repository: 1esPipelines
+      type: git
+      name: 1ESPipelineTemplates/1ESPipelineTemplates
+      ref: refs/tags/release
+parameters:
+  - name: enable_linux_cuda
+    type: boolean
+    default: true
+
+  - name: enable_windows_cuda
+    type: boolean
+    default: true
+
+  - name: cmake_build_type
+    type: string
+    default: 'Release'
+    values:
+      - Debug
+      - Release
+      - RelWithDebInfo
+      - MinSizeRel
+extends:
+  # The pipeline extends the 1ES PT, which will inject different SDL and compliance tasks.
+  # For non-production pipelines, use "Unofficial" as defined below.
+  # For production pipelines, use "Official".
+  template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines
+  parameters:
+    sdl:
+      tsa:
+        enabled: true
+      codeSignValidation:
+        enabled: true
+        break: true
+      policheck:
+        enabled: true
+        exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml'
+    pool:
+      name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool
+      os: windows # OS of the image. This value cannot be a variable. Allowed values: windows, linux, macOS
+
+    stages:
+      - template: stages/py-gpu-packaging-stage.yml
+        parameters:
+          enable_linux_cuda: ${{ parameters.enable_linux_cuda }}
+          enable_windows_cuda: ${{ parameters.enable_windows_cuda }}
+          cmake_build_type: ${{ parameters.cmake_build_type }}
+          cuda_version: '11.8'
diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml
index 79b69bf34cbef..307415b7be16f 100644
--- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml
@@ -9,15 +9,15 @@ stages:
   # ****The following Stage depend on all previous tags.
*** # GPU resources are very limited, # To utilize gpu resource more efficiently, run GPU job only after all cpus jobs succeed - - stage: Linux_Test_GPU_x86_64_stage + - stage: Linux_Test_CUDA_x86_64_stage dependsOn: jobs: - - template: stages/jobs/py-linux-cuda-package-test-job.yml + - template: templates/py-packaging-linux-test-cuda.yml parameters: - CudaVersion: '12.2' + arch: 'x86_64' machine_pool: 'Onnxruntime-Linux-GPU' + python_wheel_suffix: '_gpu' timeout: 480 - build_id: ${{ parameters.build_id }} - project: ${{ parameters.project }} - pipeline: ${{ parameters.pipeline }} + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 + cuda_version: '12.2' diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml index 3503857a9233c..2e040698fad2a 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -1,12 +1,20 @@ trigger: none - +# The `resources` specify the location and version of the 1ES PT. +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release parameters: - - name: enable_linux_gpu + - name: enable_linux_cuda type: boolean default: true - - name: enable_windows_gpu + + - name: enable_windows_cuda type: boolean default: true + - name: cmake_build_type type: string default: 'Release' @@ -15,28 +23,31 @@ parameters: - Release - RelWithDebInfo - MinSizeRel - - name: cuda_version - type: string - default: '12.2' - values: - - 11.8 - - 12.2 - - name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - - name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' -stages: - - template: stages/py-cuda-packaging-stage.yml - parameters: - enable_linux_gpu: ${{ parameters.enable_linux_gpu }} - enable_windows_gpu: ${{ parameters.enable_windows_gpu }} - cmake_build_type: ${{ parameters.cmake_build_type }} - cuda_version: ${{ parameters.cuda_version }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + sdl: + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + pool: + name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool + os: windows # OS of the image. This value cannot be a variable. 
Allowed values: windows, linux, macOS + + stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_linux_cuda: ${{ parameters.enable_linux_cuda }} + enable_windows_cuda: ${{ parameters.enable_windows_cuda }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: '12.2' diff --git a/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml new file mode 100644 index 0000000000000..371d233897c8d --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml @@ -0,0 +1,41 @@ +trigger: none +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release +parameters: + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + sdl: + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + + pool: + name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool + os: windows # OS of the image. This value cannot be a variable. Allowed values: windows, linux, macOS + + stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_windows_dml: true + cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index fc66cd9f145f7..a0e49692220f9 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -42,27 +42,13 @@ stages: # GPU resources are very limited, # To utilize gpu resource more efficiently, run GPU job only after all cpus jobs succeed -- stage: Linux_Test_GPU_x86_64_stage - dependsOn: - - Linux_Test_CPU_x86_64_stage - - Linux_Test_CPU_aarch64_stage - - Packages_Somking_Test - jobs: - - template: templates/py-packaging-linux-test-cuda.yml - parameters: - arch: 'x86_64' - machine_pool: 'Onnxruntime-Linux-GPU' - python_wheel_suffix: '_gpu' - timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 - trt_version: '10.4.0.26-1.cuda11.8' - cuda_version: '11.8' - # if final job not extecuted, it will not run nightlly build - stage: Final dependsOn: - - Linux_Test_GPU_x86_64_stage + - Linux_Test_CPU_x86_64_stage + - Linux_Test_CPU_aarch64_stage + - Packages_Somking_Test jobs: - job: Final # Run this step only if all previous steps are succeeded and (this build was triggered by a resource trigger or it was triggered by another build). 
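Several of the hunks above stop hard-coding TensorRT versions and instead read values such as linux_trt_version_cuda11, linux_trt_version_cuda12 and win_trt_folder_cuda11/12 from a shared variable template included via "- template: templates/common-variables.yml" (relative path adjusted per file). The contents of that template are not part of this diff; the following is only a minimal sketch of what such a template could look like, with assumed values patterned on the versions that do appear in the hunks above:

# templates/common-variables.yml -- illustrative sketch only; names mirror the references above, values are assumptions
variables:
- name: linux_trt_version_cuda11
  value: 10.6.0.26-1.cuda11.8   # assumed; mirrors the trt_version used by the CUDA 11.8 test stage above
- name: linux_trt_version_cuda12
  value: 10.6.0.26-1.cuda12.6   # assumed
- name: win_trt_folder_cuda11
  value: TensorRT-10.6.0.26.Windows10.x86_64.cuda-11.8   # assumed; follows the folder-name pattern of the old hard-coded flags
- name: win_trt_folder_cuda12
  value: TensorRT-10.6.0.26.Windows10.x86_64.cuda-12.6   # assumed

A consuming job includes the template in its variables block and can post-process the value with template expressions, for example deriving the Debian-style package version the same way the test_linux.yml hunk does, by swapping '-1.' for '-1+':

variables:
- template: templates/common-variables.yml   # path relative to the consuming pipeline
steps:
- script: echo "TRT_VERSION=${{ replace(variables.linux_trt_version_cuda12, '-1.', '-1+') }}"
  displayName: 'Show Debian-style TensorRT version (illustrative)'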
diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 7263239c6c7f0..bd33282fd494e 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -4,21 +4,11 @@ parameters: type: boolean default: true -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' - type: boolean - default: true - - name: enable_windows_cpu displayName: 'Whether Windows CPU package is built.' type: boolean default: true -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' - type: boolean - default: true - - name: enable_mac_cpu displayName: 'Whether Mac CPU package is built.' type: boolean @@ -69,17 +59,15 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.26.0.240828 + default: 2.28.2.241116 trigger: none stages: -- template: templates/py-packaging-stage.yml +- template: stages/py-cpu-packaging-stage.yml parameters: - enable_linux_gpu: ${{ parameters.enable_linux_gpu }} enable_linux_cpu: ${{ parameters.enable_linux_cpu }} enable_windows_cpu: ${{ parameters.enable_windows_cpu }} - enable_windows_gpu: ${{ parameters.enable_windows_gpu }} enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} @@ -89,4 +77,3 @@ stages: build_py_parameters: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} qnn_sdk_version: ${{ parameters.qnn_sdk_version }} - publish_symbols: true diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 98b5e47c0e2d7..d54b8018c232a 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml index f4022a80b0568..471e911843aed 100644 --- a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml @@ -85,10 +85,10 @@ stages: - job: Linux_C_API_Packaging_ROCm_x64 workspace: clean: all - timeoutInMinutes: 240 + timeoutInMinutes: 480 pool: onnxruntime-Ubuntu2204-AMD-CPU variables: - RocmVersion: '5.6' + RocmVersion: '6.2' RocmVersionPatchSuffix: '' steps: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime @@ -255,7 +255,7 @@ stages: - task: UsePythonVersion@0 displayName: 'Use Python' inputs: - versionSpec: 3.8 + versionSpec: 3.12 - task: MSBuild@1 displayName: 'Build Nuget Packages' @@ -340,14 +340,3 @@ stages: SpecificArtifact: ${{ parameters.specificArtifact }} CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' BuildId: ${{ parameters.BuildId }} - -- template: templates/publish-nuget-steps.yml - parameters: - download_artifacts_steps: - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - 
Signed NuGet ROCm Package' - ArtifactName: 'drop-signed-nuget-ROCm' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml new file mode 100644 index 0000000000000..1d2393d8f96d5 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml @@ -0,0 +1,21 @@ +resources: + pipelines: + - pipeline: build + source: 'Nuget ROCM Packaging pipeline' + trigger: + branches: + include: + - main + - rel-* + branch: main + +# ROCm +stages: +- template: templates/publish-nuget-steps.yml + parameters: + stage_name: 'Publish_ROCM_NuGet_Package' + download_artifacts_steps: + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-ROCm' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml index 7bc61268805f2..716383fd61dbb 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -58,6 +58,11 @@ stages: showWarnings: true workingDirectory: '$(Build.BinariesDirectory)\java-artifact' + - template: ../templates/jar-esrp-dll.yml + parameters: + JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + JarFileName: 'onnxruntime_gpu-$(OnnxRuntimeVersion).jar' + - template: ../templates/jar-maven-signing-win.yml parameters: JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' @@ -143,9 +148,9 @@ stages: value: false - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241008.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 timeoutInMinutes: 60 steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 545996a4ffc79..47092393e0039 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -42,30 +42,31 @@ jobs: dependsOn: [ ] timeoutInMinutes: ${{ parameters.timeout }} variables: + - template: ../../templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241008.1 + value: 
onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} pool: ${{ parameters.machine_pool }} steps: - checkout: self - task: DownloadPipelineArtifact@2 inputs: - artifact: 'drop-linux-gpu-x86_64' - targetPath: '$(Build.SourcesDirectory)/drop-linux-gpu-x86_64' + artifact: 'linux_gpu_wheel_x86_64' + targetPath: '$(Build.SourcesDirectory)/linux_gpu_wheel_x86_64' ${{ if ne(parameters.build_id, 'latest') }}: buildType: 'specific' project: '${{ parameters.project }}' pipeline: '${{ parameters.pipeline }}' buildVersionToDownload: 'specific' buildId: '${{ parameters.build_id }}' - displayName: 'Download Build Artifacts - drop-linux-gpu-x86_64' + displayName: 'Download Build Artifacts - linux_gpu_wheel_x86_64' - task: DownloadPipelineArtifact@2 inputs: @@ -82,7 +83,7 @@ jobs: - bash: | set -e -x ls $(Build.SourcesDirectory) - mv "$(Build.SourcesDirectory)/drop-linux-gpu-x86_64" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Build.SourcesDirectory)/linux_gpu_wheel_x86_64" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} mv "$(Build.SourcesDirectory)/onnxruntime_gpu" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index b8ade5d36f5a1..7133031c84f49 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -135,7 +135,7 @@ stages: - task: UsePythonVersion@0 displayName: 'Use Python' inputs: - versionSpec: 3.8 + versionSpec: 3.12 - task: MSBuild@1 displayName: 'Build Nuget Packages' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index fc6da88917f62..d331c76bc264e 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -68,6 +68,7 @@ stages: timeoutInMinutes: 180 pool: 'onnxruntime-Ubuntu2204-AMD-CPU' variables: + - template: ../templates/common-variables.yml - name: CUDA_VERSION_MAJOR ${{ if eq(parameters.CudaVersion, '11.8') }}: value: '11' @@ -75,12 +76,11 @@ stages: value: '12' - name: CUDA_VERSION value: ${{ parameters.CudaVersion }} - - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} steps: - checkout: self clean: true @@ -140,6 +140,7 @@ stages: clean: all pool: 'Onnxruntime-Linux-GPU' variables: + - template: ../templates/common-variables.yml - name: CUDA_VERSION_MAJOR ${{ if eq(parameters.CudaVersion, '11.8') }}: value: '11' @@ -147,9 +148,9 @@ stages: value: '12' - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') 
}}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} steps: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime submodules: false diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml similarity index 66% rename from tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml rename to tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index edaae227ee78f..72df94c9ea672 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -10,21 +10,11 @@ parameters: type: boolean default: true -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' - type: boolean - default: true - - name: enable_windows_cpu displayName: 'Whether Windows CPU package is built.' type: boolean default: true -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' - type: boolean - default: true - - name: enable_mac_cpu displayName: 'Whether Mac CPU package is built.' type: boolean @@ -65,15 +55,11 @@ parameters: - RelWithDebInfo - MinSizeRel -- name: publish_symbols - type: boolean - default: false - # Only applies to QNN packages. - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.26.0.240828 + default: 2.28.2.241116 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: @@ -96,6 +82,10 @@ stages: PythonVersion: '3.12' MsbuildPlatform: x64 buildArch: x64 + Python313_x64: + PythonVersion: '3.13' + MsbuildPlatform: x64 + buildArch: x64 # Training build cannot support Win32 for now because one or more of its python # dependencies does not support Win32. 
So, don't build a training package for Win32 ${{ if not(contains(parameters.build_py_parameters, '--enable_training')) }}: @@ -124,7 +114,7 @@ stages: clean: true submodules: recursive - - template: telemetry-steps.yml + - template: ../templates/telemetry-steps.yml - task: UsePythonVersion@0 inputs: @@ -138,7 +128,7 @@ stages: tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' appendSourceBranchName: false - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - task: BatchScript@1 displayName: 'setup env' @@ -147,7 +137,7 @@ stages: modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' - - template: download-deps.yml + - template: ../templates/download-deps.yml - task: PythonScript@0 displayName: 'Update deps.txt' @@ -176,24 +166,12 @@ stages: --enable_pybind --enable_onnx_tests ${{ parameters.build_py_parameters }} - --parallel --use_binskim_compliant_compile_flags --update + --parallel --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\onnxruntime.sln' - platform: $(MsbuildPlatform) - configuration: ${{ parameters.cmake_build_type }} - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}' - createLogFile: true - # Esrp signing - - template: win-esrp-dll.yml + - template: ../templates/win-esrp-dll.yml parameters: FolderPath: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' @@ -237,7 +215,7 @@ stages: continueOnError: true - powershell: | - python -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu -qq + python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} Remove-Item -Recurse -Force onnxruntime if ("$(ExtraParam)" -contains "--use_azure") { @@ -247,29 +225,8 @@ stages: python onnx_backend_test_series.py workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' displayName: 'Run Python Tests' - - ${{ if eq(parameters.publish_symbols, true) }}: - - task: PublishSymbols@2 - displayName: 'Publish symbols' - condition: and (succeeded(), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))) - inputs: - SymbolsFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' - SearchPattern: | - onnxruntime_pybind11_state.pdb - onnxruntime_providers_shared.pdb - IndexSources: true - SymbolServerType: TeamServices - SymbolExpirationInDays: 3650 - SymbolsArtifactName: 'win_cpu_$(PythonVersion)_$(buildArch)_$(Build.BuildNumber)' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and(and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - continueOnError: true - - template: component-governance-component-detection-steps.yml + - 
template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -277,67 +234,6 @@ stages: displayName: 'Clean Agent Directories' condition: always() -- ${{ if eq(parameters.enable_windows_gpu, true) }}: - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - ${{ if eq(parameters.enable_mac_cpu, true) }}: - stage: Python_Packaging_MacOS dependsOn: [] @@ -358,6 +254,9 @@ stages: PythonVersion: '3.11' Python312: PythonVersion: '3.12' + Python313: + PythonVersion: '3.13' + steps: - checkout: self clean: true @@ -368,9 +267,9 @@ stages: inputs: versionSpec: $(PythonVersion) - - template: use-xcode-version.yml + - template: ../templates/use-xcode-version.yml - - template: download-deps.yml + - template: ../templates/download-deps.yml - task: PythonScript@0 displayName: 'Update deps.txt' @@ -410,7 +309,7 @@ stages: inputs: ArtifactName: onnxruntime - - template: 
component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -419,7 +318,7 @@ stages: - stage: Python_Packaging_Linux_ARM dependsOn: [] jobs: - - template: py-linux.yml + - template: ../templates/py-linux.yml parameters: arch: 'aarch64' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' @@ -430,30 +329,18 @@ stages: - stage: Python_Packaging_Linux_CPU dependsOn: [] jobs: - - template: py-linux.yml - parameters: - arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - - ${{ if eq(parameters.enable_linux_gpu, true) }}: - - template: py-linux-gpu.yml + - template: ../templates/py-linux.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} - trt_version: '10.4.0.26-1.cuda11.8' - cuda_version: '11.8' - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: - stage: Python_Packaging_Windows_ARM64_QNN dependsOn: [] jobs: - - template: py-win-arm64-qnn.yml + - template: ../templates/py-win-arm64-qnn.yml parameters: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -463,7 +350,7 @@ stages: - stage: Python_Packaging_Windows_arm64ec_QNN dependsOn: [] jobs: - - template: py-win-arm64ec-qnn.yml + - template: ../templates/py-win-arm64ec-qnn.yml parameters: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -473,7 +360,7 @@ stages: - stage: Python_Packaging_Windows_x64_QNN dependsOn: [] jobs: - - template: py-win-x64-qnn.yml + - template: ../templates/py-win-x64-qnn.yml parameters: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -483,7 +370,7 @@ stages: - stage: Python_Packaging_Linux_x64_QNN dependsOn: [] jobs: - - template: py-linux-qnn.yml + - template: ../templates/py-linux-qnn.yml parameters: machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml deleted file mode 100644 index f9f90b43f0cf6..0000000000000 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +++ /dev/null @@ -1,85 +0,0 @@ -parameters: -- name: build_py_parameters - displayName: > - Extra parameters to pass to build.py. Don't put newlines in here. - type: string - default: '' - -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' - type: boolean - default: true - -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' - type: boolean - default: true - -# TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. -- name: cmake_build_type - type: string - displayName: 'Linux packages cmake build type. Linux Only.' - default: 'Release' - values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel - -- name: cuda_version - type: string - displayName: 'CUDA version. Windows Only.' 
- default: '12.2' - values: - - 11.8 - - 12.2 - -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - -- name: PythonVersions - type: object - displayName: 'Python versions to build' - default: - - '3.8' - - '3.9' - - '3.10' - - '3.11' - - '3.12' - -stages: - - ${{ if eq(parameters.enable_windows_gpu, true) }}: - - ${{ each python_version in parameters.PythonVersions }}: - - template: ../templates/py-win-gpu.yml - parameters: - PYTHON_VERSION: ${{ python_version }} - EP_NAME: gpu - CudaVersion: ${{ parameters.cuda_version }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - ${{ if eq(parameters.cuda_version, '11.8') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ${{ if eq(parameters.cuda_version, '12.2') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 --cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - - - ${{ if eq(parameters.enable_linux_gpu, true) }}: - - template: ../templates/py-linux-gpu.yml - parameters: - arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - cuda_version: ${{ parameters.cuda_version }} - ${{ if eq(parameters.cuda_version, '11.8') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 - trt_version: 10.4.0.26-1.cuda11.8 - ${{ if eq(parameters.cuda_version, '12.2') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241008.1 - trt_version: 10.4.0.26-1.cuda12.6 diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml new file mode 100644 index 0000000000000..947e4f99b984f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -0,0 +1,83 @@ +parameters: +- name: build_py_parameters + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +- name: enable_linux_cuda + displayName: 'Whether Linux CUDA package is built.' + type: boolean + default: false + +- name: enable_windows_cuda + displayName: 'Whether Windows CUDA package is built.' + type: boolean + default: false + +- name: enable_windows_dml + displayName: 'Whether Windows DML package is built.' + type: boolean + default: false + +# TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. +- name: cmake_build_type + type: string + displayName: 'Linux packages cmake build type. Linux Only.' + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: cuda_version + type: string + displayName: 'CUDA version. Windows Only.' 
+ default: '12.2' + values: + - 11.8 + - 12.2 + +- name: PythonVersions + type: object + displayName: 'Python versions to build' + default: + - '3.10' + - '3.11' + - '3.12' + - '3.13' + +stages: + - ${{ if eq(parameters.enable_windows_cuda, true) }}: + - ${{ each python_version in parameters.PythonVersions }}: + - template: py-win-gpu-stage.yml + parameters: + PYTHON_VERSION: ${{ python_version }} + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + EP_BUILD_FLAGS: --enable_lto --cuda_home=$(Agent.TempDirectory)\v${{ parameters.cuda_version }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + use_tensorrt: True + + - ${{ if eq(parameters.enable_linux_cuda, true) }}: + - template: py-linux-gpu-stage.yml + parameters: + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: ${{ parameters.cuda_version }} + ${{ if eq(parameters.cuda_version, '11.8') }}: + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 + ${{ if eq(parameters.cuda_version, '12.2') }}: + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 + + - ${{ if eq(parameters.enable_windows_dml, true) }}: + - ${{ each python_version in parameters.PythonVersions }}: + - template: py-win-gpu-stage.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' + PYTHON_VERSION: ${{ python_version }} + EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos + EP_NAME: directml + cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml similarity index 53% rename from tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml rename to tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml index d19472bcbab5a..3f26d2d5aeca3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml @@ -20,12 +20,6 @@ parameters: - name: docker_base_image type: string -- name: trt_version - type: string - default: '10.4.0.26-1.cuda11.8' - values: - - 10.4.0.26-1.cuda11.8 - - 10.4.0.26-1.cuda12.6 - name: cuda_version type: string default: '11.8' @@ -41,7 +35,27 @@ stages: timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.machine_pool }} + pool: + name: ${{ parameters.machine_pool }} + os: linux + templateContext: + codeSignValidation: + enabled: true + break: true + psscriptanalyzer: + enabled: true + sdl: + binskim: + enabled: true + scanOutputDirectoryOnly: true + targetPathPattern: '\".*.so\"' + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/dist + artifactName: onnxruntime_gpu + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/${{ parameters.cmake_build_type }} + artifactName: linux_gpu_wheel_${{ parameters.arch }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. 
- name: skipComponentGovernanceDetection @@ -51,18 +65,24 @@ stages: value: -x ${{ parameters.extra_build_arg }} ${{ if eq(parameters.extra_build_arg, '') }}: value: '' + - template: ../templates/common-variables.yml + - name: trt_version + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: ${{ variables.linux_trt_version_cuda12 }} steps: - checkout: self clean: true submodules: recursive - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - - template: get-docker-image-steps.yml + - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda/Dockerfile Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda - DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ parameters.trt_version }} --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ variables.trt_version }} --build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} @@ -73,17 +93,19 @@ stages: filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime_gpu - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Test Binaries' - inputs: - artifactName: 'drop-linux-gpu-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/Release' + - script: | + set -e -x + mv $(Build.BinariesDirectory)/${{ parameters.cmake_build_type }} ./${{ parameters.cmake_build_type }} + mv $(Build.BinariesDirectory)/dist ./dist + pushd dist + find . -name \*.whl -exec unzip -qq -o {} \; + rm -r onnxruntime + popd + pushd ${{ parameters.cmake_build_type }} + find . -name \*.whl -exec unzip -qq -o {} \; + popd + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'Move files' - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml similarity index 51% rename from tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml rename to tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index e89227d51de32..aa7f2845fc0fa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -12,10 +12,6 @@ parameters: - name: EP_BUILD_FLAGS type: string -- name: ENV_SETUP_SCRIPT - type: string - default: '' - - name: BUILD_PY_PARAMETERS displayName: > Extra parameters to pass to build.py. Don't put newlines in here. 
@@ -28,16 +24,6 @@ parameters: - 11.8 - 12.2 -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - - name: cmake_build_type type: string displayName: 'Linux packages cmake build type. Linux Only.' @@ -47,8 +33,8 @@ parameters: - Release - RelWithDebInfo - MinSizeRel - -- name: publish_symbols + +- name: use_tensorrt type: boolean default: false @@ -62,10 +48,40 @@ stages: clean: all pool: name: onnxruntime-Win-CPU-2022 + os: windows + templateContext: + codeSignValidation: + enabled: true + break: true + psscriptanalyzer: + enabled: true + sdl: + binskim: + enabled: true + scanOutputDirectoryOnly: true + targetPathPattern: '+:file|*.dll;-:file|DirectML.dll' + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: win_${{ parameters.EP_NAME }}_wheel_${{ parameters.PYTHON_VERSION }} variables: - GRADLE_OPTS: '-Dorg.gradle.daemon=false' - VSGenerator: 'Visual Studio 17 2022' - CUDA_MODULE_LOADING: 'LAZY' + - template: ../templates/common-variables.yml + - name: GRADLE_OPTS + value: '-Dorg.gradle.daemon=false' + - name: VSGenerator + value: 'Visual Studio 17 2022' + - name: CUDA_MODULE_LOADING + value: 'LAZY' + - name: win_trt_folder + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: ${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: ${{ variables.win_trt_folder_cuda12 }} + - name: trt_build_flag + ${{ if eq(parameters.use_tensorrt, true) }}: + value: '--use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder }}"' + ${{ if eq(parameters.use_tensorrt, false) }}: + value: '' steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -75,7 +91,7 @@ stages: clean: true submodules: recursive - - template: telemetry-steps.yml + - template: ../templates/telemetry-steps.yml - task: UsePythonVersion@0 inputs: @@ -83,31 +99,20 @@ stages: addToPath: true architecture: 'x64' - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: false + artifactFeeds: 'Lotus' - - template: download-deps.yml - - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: - DownloadCUDA: true - ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}: - DownloadTRT: true + - template: ../templates/download-deps.yml - - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: jobs/download_win_gpu_library.yml - parameters: - CudaVersion: ${{ parameters.CudaVersion }} - ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: - DownloadCUDA: true - ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}: - DownloadTRT: true + - template: ../templates/jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), eq(parameters.use_tensorrt, true)) }}: + DownloadCUDA: true + DownloadTRT: ${{ parameters.use_tensorrt }} - task: PythonScript@0 displayName: 
'Update deps.txt' @@ -116,14 +121,7 @@ stages: arguments: --new_dir $(Build.BinariesDirectory)/deps workingDirectory: $(Build.BinariesDirectory) - - task: PowerShell@2 - displayName: 'Install ONNX' - inputs: - filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' - workingDirectory: '$(Build.BinariesDirectory)' - arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\installed -build_config ${{ parameters.cmake_build_type }} - - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - task: PythonScript@0 displayName: 'Generate cmake config' @@ -136,24 +134,12 @@ stages: --cmake_generator "$(VSGenerator)" --enable_pybind --enable_onnx_tests - --parallel --use_binskim_compliant_compile_flags --update - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} - workingDirectory: '$(Build.BinariesDirectory)' - - # building with build.py so the parallelization parameters are added to the msbuild command - - task: PythonScript@0 - displayName: 'Build' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config ${{ parameters.cmake_build_type }} - --build_dir $(Build.BinariesDirectory) - --parallel --build - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} + --parallel --use_binskim_compliant_compile_flags --update --build + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} ${{ variables.trt_build_flag }} workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing - - template: win-esrp-dll.yml + - template: ../templates/win-esrp-dll.yml parameters: FolderPath: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' @@ -174,49 +160,12 @@ stages: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime_${{ parameters.EP_NAME }} - - - ${{ if eq(parameters.publish_symbols, true) }}: - - task: PublishSymbols@2 - displayName: 'Publish symbols' - condition: and (succeeded(), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))) - inputs: - SymbolsFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' - SearchPattern: | - onnxruntime_pybind11_state.pdb - onnxruntime_providers_shared.pdb - IndexSources: true - SymbolServerType: TeamServices - SymbolExpirationInDays: 3650 - SymbolsArtifactName: 'win_${{ parameters.EP_NAME }}_${{ parameters.PYTHON_VERSION }}_$(Build.BuildNumber)' - - script: | 7z x *.whl workingDirectory: '$(Build.ArtifactStagingDirectory)' displayName: 'unzip the package' - - task: CredScan@3 - displayName: 'Run CredScan' - inputs: - debugMode: false - continueOnError: true - - - task: BinSkim@4 - displayName: 'Run BinSkim' - inputs: - AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll;-:file|$(Build.ArtifactStagingDirectory)\**\DirectML.dll' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - 
template: component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -243,24 +192,28 @@ stages: addToPath: true architecture: 'x64' - - template: flex-downloadPipelineArtifact.yml + - template: ../templates/flex-downloadPipelineArtifact.yml parameters: - ArtifactName: onnxruntime_${{ parameters.EP_NAME }} + ArtifactName: win_${{ parameters.EP_NAME }}_wheel_${{ parameters.PYTHON_VERSION }} StepName: 'Download Pipeline Artifact - Windows GPU Build' TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - powershell: | - pushd onnxruntime/test/python - python -m pip install --upgrade pip - python -m pip install -r requirements.txt - popd - workingDirectory: '$(Build.SourcesDirectory)' + - template: ../templates/jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), eq(parameters.use_tensorrt, true)) }}: + DownloadCUDA: true + DownloadTRT: ${{ parameters.use_tensorrt }} + + - task: PowerShell@2 displayName: 'Install ONNX' + inputs: + filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' + workingDirectory: '$(Build.BinariesDirectory)' + arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\installed -build_config ${{ parameters.cmake_build_type }} - powershell: | - python -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu -qq + python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} mkdir -p $(Agent.TempDirectory)\ort_test_data Copy-Item -Path $(Build.sourcesDirectory)/onnxruntime/test/python/onnx_backend_test_series.py -Destination $(Agent.TempDirectory)\ort_test_data diff --git a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml index acce2a4098ed0..4d9606d82ced2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml @@ -91,7 +91,7 @@ stages: -e BUILD_REASON=$(Build.Reason) \ -e BUILD_BRANCH=$(Build.SourceBranch) \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ --build_dir /build/1a \ ${BINARY_SIZE_THRESHOLD_ARGS} \ "/onnxruntime_src/${{ parameters.BuildConfigFile }}" @@ -147,7 +147,7 @@ stages: -e BUILD_REASON=$(Build.Reason) \ -e BUILD_BRANCH=$(Build.SourceBranch) \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ --build_dir /build/1b \ --with_debug_info \ "/onnxruntime_src/${{ parameters.BuildConfigFile }}" diff --git 
a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index fe83a91b2f3d4..29caa7fa4955a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -14,6 +14,16 @@ parameters: type: string default: 'onnxruntime-android' +- name: ReleaseVersionSuffix + displayName: Release Version Suffix + type: string + default: '' + +- name: QnnSDKVersion + displayName: QNN SDK Version + type: string + default: '2.28.0.241029' + jobs: - job: Final_AAR_Testing_Android_${{ parameters.job_name_suffix }} workspace: @@ -45,36 +55,61 @@ jobs: - template: use-android-ndk.yml - - template: use-android-emulator.yml - parameters: - create: true - start: true - - script: | - set -e -x - mkdir android_test - cd android_test - cp -av $(Build.SourcesDirectory)/java/src/test/android ./ - cd ./android - mkdir -p app/libs - cp $(Build.BinariesDirectory)/final-android-aar/${{parameters.packageName}}-$(OnnxRuntimeVersion).aar app/libs/onnxruntime-android.aar - $(Build.SourcesDirectory)/java/gradlew --no-daemon clean connectedDebugAndroidTest --stacktrace - displayName: Run E2E test using Emulator + set -e -x + mkdir -p android_test/android/app/libs + cd android_test/android + cp -av $(Build.SourcesDirectory)/java/src/test/android/* ./ + cp $(Build.BinariesDirectory)/final-android-aar/${{parameters.packageName}}-$(OnnxRuntimeVersion)${{parameters.ReleaseVersionSuffix}}.aar app/libs/${{parameters.packageName}}.aar + displayName: Copy Android test files and AAR to android_test directory workingDirectory: $(Build.BinariesDirectory) - - template: use-android-emulator.yml - parameters: - stop: true + # skip emulator tests for qnn package as there are no arm64-v8a emulators and no qnn libraries for x86 + - ${{ if not(contains(parameters.packageName, 'qnn')) }}: + - template: use-android-emulator.yml + parameters: + create: true + start: true + + - script: | + set -e -x + cd android_test/android + $(Build.SourcesDirectory)/java/gradlew --no-daemon clean connectedDebugAndroidTest --stacktrace + displayName: Run E2E test using Emulator + workingDirectory: $(Build.BinariesDirectory) + + - template: use-android-emulator.yml + parameters: + stop: true + + - ${{ else }}: + - script: | + # QNN SDK version string, expected format: 2.28.0.241029 + # Extract the first three parts of the version string to get the Maven package version (e.g., 2.28.0) + QnnMavenPackageVersion=$(echo ${{ parameters.QnnSDKVersion }} | cut -d'.' 
-f1-3) + echo "QnnMavenPackageVersion: $QnnMavenPackageVersion" + echo "##vso[task.setvariable variable=QnnMavenPackageVersion]$QnnMavenPackageVersion" + displayName: Trim QNN SDK version to major.minor.patch + + - script: | + set -e -x + # build apks for qnn package as they are not built in the emulator test step + $(Build.SourcesDirectory)/java/gradlew --no-daemon clean assembleDebug assembleAndroidTest -DqnnVersion=$(QnnMavenPackageVersion) --stacktrace + displayName: Build QNN APK + workingDirectory: $(Build.BinariesDirectory)/android_test/android # we run e2e tests on one older device (Pixel 3) and one newer device (Galaxy 23) - script: | set -e -x pip install requests + python $(Build.SourcesDirectory)/tools/python/upload_and_run_browserstack_tests.py \ --test_platform espresso \ - --app_apk_path "debug/app-debug.apk" \ - --test_apk_path "androidTest/debug/app-debug-androidTest.apk" \ - --devices "Samsung Galaxy S23-13.0" "Google Pixel 3-9.0" + --app_path "debug/app-debug.apk" \ + --test_path "androidTest/debug/app-debug-androidTest.apk" \ + --devices "Samsung Galaxy S23-13.0" "Google Pixel 3-9.0" \ + --build_tag "${{ parameters.packageName }}" + displayName: Run E2E tests using Browserstack workingDirectory: $(Build.BinariesDirectory)/android_test/android/app/build/outputs/apk timeoutInMinutes: 15 diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 73f4620440a6c..c38736edd58f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -43,6 +43,16 @@ parameters: displayName: Use GPG to sign the jars type: boolean +- name: ReleaseVersionSuffix + displayName: Release Version Suffix + type: string + default: '' + +- name: QnnSDKVersion + displayName: QNN SDK Version + type: string + default: '2.28.0.241029' + jobs: - job: Android_Java_API_AAR_Packaging_${{ parameters.job_name_suffix }} timeoutInMinutes: 120 @@ -78,6 +88,11 @@ jobs: - template: use-android-ndk.yml + - ${{ if contains(parameters.packageName, 'qnn') }}: + - template: jobs/download_linux_qnn_sdk.yml + parameters: + QnnSDKVersion: '${{parameters.QnnSDKVersion}}' + - task: CmdLine@2 displayName: Build Android AAR Packages inputs: @@ -88,6 +103,15 @@ jobs: cp ${{parameters.buildSettings}} $(Build.BinariesDirectory)/.build_settings/build_settings.json [ -f "${{parameters.includedOpsConfig}}" ] && \ cp ${{parameters.includedOpsConfig}} $(Build.BinariesDirectory)/.build_settings/include_ops_and_types.config + + # Mount qnn volume if building qnn android package + if [[ ${{ parameters.packageName }} == *qnn* ]]; then + QNN_VOLUME="--volume $(QnnSDKRootDir):/qnn_home" + USE_QNN="1" + else + QNN_VOLUME="" + USE_QNN="0" + fi docker run --rm \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ @@ -95,14 +119,16 @@ jobs: --volume $NDK_HOME:/ndk_home \ --volume $(artifacts_directory):/home/onnxruntimedev/.artifacts \ --volume $(Build.BinariesDirectory)/.build_settings:/home/onnxruntimedev/.build_settings \ + $QNN_VOLUME \ -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ -e BUILD_CONFIG=${{parameters.buildConfig}} \ -e ORT_VERSION=$(OnnxRuntimeVersion) \ -e PUBLISH_EXECUTABLES=${{parameters.publish_executables}} \ -e PACKAGE_NAME=${{parameters.packageName}} \ + -e RELEASE_VERSION_SUFFIX=${{parameters.ReleaseVersionSuffix}} \ onnxruntimecpubuild \ - /bin/bash 
/onnxruntime_src/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh + /bin/bash /onnxruntime_src/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh $USE_QNN workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index e99538a595f69..b105e919c5b12 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -48,6 +48,11 @@ parameters: type: string default: '0' +- name: QnnSDKVersion + displayName: QNN SDK Version + type: string + default: 2.28.0.241029 + stages: - template: linux-cpu-packaging-pipeline.yml parameters: @@ -62,7 +67,9 @@ stages: DoEsrp: ${{ parameters.DoEsrp }} - stage: Android_Java_API_AAR_Packaging_Full - dependsOn: [] + dependsOn: Setup # Setup stage defined in set_packaging_variables_stage.yml creates the ReleaseVersionSuffix variable + variables: + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] jobs: - template: android-java-api-aar.yml parameters: @@ -72,11 +79,38 @@ stages: job_name_suffix: 'Full' publish_executables: '1' enable_code_sign: ${{ parameters.DoEsrp }} + packageName: 'onnxruntime-android' + ReleaseVersionSuffix: $(ReleaseVersionSuffix) - template: android-java-api-aar-test.yml parameters: artifactName: 'onnxruntime-android-full-aar' job_name_suffix: 'Full' + ReleaseVersionSuffix: $(ReleaseVersionSuffix) + +- stage: Android_Java_API_AAR_Packaging_QNN + dependsOn: Setup # Setup stage defined in set_packaging_variables_stage.yml creates the ReleaseVersionSuffix variable + variables: + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + jobs: + - template: android-java-api-aar.yml + parameters: + buildConfig: 'Release' + buildSettings: '$(Build.SourcesDirectory)/tools/ci_build/github/android/default_qnn_aar_build_settings.json' + artifactName: 'onnxruntime-android-qnn-aar' + job_name_suffix: 'QNN' + publish_executables: '0' + enable_code_sign: ${{ parameters.DoEsrp }} + packageName: 'onnxruntime-android-qnn' + ReleaseVersionSuffix: $(ReleaseVersionSuffix) + QnnSDKVersion: ${{ parameters.QnnSDKVersion }} + + - template: android-java-api-aar-test.yml + parameters: + artifactName: 'onnxruntime-android-qnn-aar' + job_name_suffix: 'QNN' + packageName: 'onnxruntime-android-qnn' + QnnSDKVersion: ${{ parameters.QnnSDKVersion }} - stage: iOS_Full_xcframework dependsOn: [] @@ -238,6 +272,11 @@ stages: showWarnings: true workingDirectory: '$(Build.BinariesDirectory)\java-artifact' + - template: jar-esrp-dll.yml + parameters: + JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + JarFileName: 'onnxruntime-$(OnnxRuntimeVersion).jar' + - template: jar-maven-signing-win.yml parameters: JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' @@ -419,7 +458,7 @@ stages: - task: UsePythonVersion@0 displayName: 'Use Python' inputs: - versionSpec: 3.8 + versionSpec: 3.12 - task: MSBuild@1 displayName: 'Build Nuget Packages' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index 0f4328f75e1bd..d3b3315ebb04c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ 
b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -5,9 +5,6 @@ parameters: type: string default: '' -- name: BaseImage - type: string - - name: OnnxruntimeArch type: string @@ -50,7 +47,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{parameters.BaseImage}}" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging - ${{ if eq(parameters.OnnxruntimeArch, 'aarch64') }}: @@ -58,16 +55,17 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/aarch64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{parameters.BaseImage}}" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging UpdateDepsTxt: false - task: CmdLine@2 inputs: script: | + set -e -x mkdir -p $HOME/.onnx docker run --rm --volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging /bin/bash -c "python3.9 \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging /bin/bash -c "python3.12 \ /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release \ --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib ${{ parameters.AdditionalBuildFlags }} && cd /build/Release && make install DESTDIR=/build/installed" workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml index e7f703fa592a3..d35bed69ee409 100644 --- a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml +++ b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml @@ -1,3 +1,7 @@ variables: - common_cuda_version: '11.8' - common_cuda_baseimg: 'nvidia/cuda:11.8.0-cudnn8-devel-ubi8' + common_trt_version: '10.6.0.26' + # As for Debian installation, replace '-1.' 
by '-1+' when assigning trt version below + linux_trt_version_cuda11: ${{ variables.common_trt_version }}-1.cuda11.8 + linux_trt_version_cuda12: ${{ variables.common_trt_version }}-1.cuda12.6 + win_trt_folder_cuda11: TensorRT-${{ variables.common_trt_version }}.Windows10.x86_64.cuda-11.8 + win_trt_folder_cuda12: TensorRT-${{ variables.common_trt_version }}.Windows10.x86_64.cuda-12.6 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index a38db0aa57d19..949479fb8b5e4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.194 + version: 1.0.201 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.194 + version: 1.0.201 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/install-appcenter.yml b/tools/ci_build/github/azure-pipelines/templates/install-appcenter.yml deleted file mode 100644 index 51be73d4c658a..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/install-appcenter.yml +++ /dev/null @@ -1,12 +0,0 @@ -# Install appcenter CLI - -parameters: -- name: appcenterVersion - type: string - default: "2.13.7" - -steps: -- bash: | - set -e -x - npm install -g appcenter-cli@${{ parameters.appcenterVersion }} - displayName: Install appcenter CLI ${{ parameters.appcenterVersion }} diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml b/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml new file mode 100644 index 0000000000000..b59ba551c222f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml @@ -0,0 +1,30 @@ +parameters: +- name: JarFileDirectory + type: string + default: '' + +- name: JarFileName + type: string + default: '' + +steps: + - task: PowerShell@2 + displayName: 'ESRP Jar - Extract Jar File' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 + arguments: extract '${{ parameters.JarFileDirectory }}' '${{ parameters.JarFileName }}' + workingDirectory: '$(Build.BinariesDirectory)' + + - template: win-esrp-dll.yml + parameters: + FolderPath: '${{ parameters.JarFileDirectory }}\jar_extracted_full_files' + DisplayName: 'ESRP Jar - Sign Dlls' + + - task: PowerShell@2 + displayName: 'ESRP Jar - Repack Jar File' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 + arguments: repack '${{ parameters.JarFileDirectory }}' '${{ parameters.JarFileName }}' + workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml index ca7e3f6148e26..d14952e544e5e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml @@ -45,7 +45,8 @@ steps: for file in $(find $jar_file_directory -type f); do 
echo "Adding checksum of sha256 to file: $file" - sha256sum $file | awk '{print $1}' >$file.sha256 + sha256_value=$(sha256sum $file | awk '{print $1}') + echo $sha256_value" *"$(basename "$file") >$file.sha256 echo "Added checksum of sha256 to file: $file" done diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml index 182a2ebe3b4c9..5681b3568bae1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml @@ -15,6 +15,7 @@ steps: displayName: 'Sign jar files: GnuPG and sha256' inputs: targetType: 'inline' + pwsh: true workingDirectory: '$(Build.SourcesDirectory)' script: | $jar_file_directory = '${{ parameters.JarFileDirectory }}' @@ -53,15 +54,22 @@ steps: Write-Host "GnuPG signed to file: "$file_path } + $PSDefaultParameterValues['Out-File:Encoding'] = 'utf8NoBOM' + $sha256sum_exe_path = "C:\Program Files\Git\usr\bin\sha256sum.exe" $targeting_asc_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name + $original_location = Get-Location + Set-Location $jar_file_directory foreach ($file in $targeting_asc_files) { - $file_path = Join-Path $jar_file_directory -ChildPath $file - Write-Host "Adding checksum of sha256 to file: "$file_path - $file_path_sha256 = $file_path + ".sha256" - CertUtil -hashfile $file_path SHA256 - CertUtil -hashfile $file_path SHA256 | find /v `"hash`" | Out-File -FilePath $file_path_sha256 - Write-Host "Added checksum of sha256 to file: "$file_path + Write-Host "Adding checksum of sha256 to file: "$file + $file_path_sha256 = $file + ".sha256" + & $sha256sum_exe_path $file 1>$file_path_sha256 + if ($lastExitCode -ne 0) { + Write-Host -Object "sha256sum command failed. Exitcode: $exitCode" + exit $lastExitCode + } + Write-Host "Added checksum of sha256 to file: "$file } + Set-Location $original_location Write-Host "GnuPG and sha256 signing to files completed." Write-Host "Deleting GnuPG key files." diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 4aedd2f8564c1..179a846509cc1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.26.0.240828' + default: '2.28.2.241116' steps: - script: | @@ -16,6 +16,29 @@ steps: echo $(QnnSDKRootDir) displayName: 'Print QnnSDKRootDir after downloading QNN SDK' + - script: | + set -x + sdk_file="$(QnnSDKRootDir)/sdk.yaml" + # Parse the sdk.yaml file to get the QNN SDK version downloaded + downloaded_qnn_sdk_version=$(grep '^version:' "$sdk_file" | head -n 1 | cut -d':' -f2 | xargs | cut -d'.' -f1-3 | tr -d '\r') + + # Extract major.minor.patch part from QnnSDKVersion passed as parameter + expected_qnn_sdk_version=$(echo ${{ parameters.QnnSDKVersion }} | cut -d'.' -f1-3) + + if [[ -z "$downloaded_qnn_sdk_version" ]]; then + echo "QNN version not found in sdk.yaml." + exit 1 + fi + + # Compare provided version with version from sdk.yaml + if [[ "$downloaded_qnn_sdk_version" == "$expected_qnn_sdk_version" ]]; then + echo "Success: QnnSDKVersion matches sdk.yaml version ($downloaded_qnn_sdk_version)." 
+ else + echo "Error: QnnSDKVersion ($expected_qnn_sdk_version) does not match sdk.yaml version ($downloaded_qnn_sdk_version) in the QNN SDK directory" + exit 1 + fi + displayName: "Sanity Check: QnnSDKVersion vs sdk.yaml version" + - script: | azcopy cp --recursive 'https://lotusscus.blob.core.windows.net/models/qnnsdk/Qualcomm AI Hub Proprietary License.pdf' $(QnnSDKRootDir) displayName: 'Download Qualcomm AI Hub license' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index e196ecb312f96..ae54b3849a862 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -13,10 +13,10 @@ parameters: - 12.2 - name: TrtVersion type: string - default: '10.4.0.26' + default: '10.6.0.26' values: - 8.6.1.6 - - 10.4.0.26 + - 10.6.0.26 steps: - ${{ if eq(parameters.DownloadCUDA, true) }}: @@ -42,7 +42,7 @@ steps: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.0" displayName: Set trtCudaVersion - - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.4.0.26')) }}: + - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.6.0.26')) }}: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.6" displayName: Set trtCudaVersion diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index eff49302eb33d..9df8b249f681e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.26.0.240828' + default: '2.28.2.241116' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index 6a2b7f4566b61..dfaf237a711fe 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -13,6 +13,12 @@ parameters: - name: SecondaryCUDAVersion type: string default: '11.8' + - name: win_trt_folder_cuda11 + type: string + default: 'TensorRT-10.6.0.26.Windows10.x86_64.cuda-11.8' + - name: win_trt_folder_cuda12 + type: string + default: 'TensorRT-10.6.0.26.Windows10.x86_64.cuda-12.6' steps: - ${{ if eq(parameters.DownloadCUDA, 'true') }}: @@ -24,11 +30,11 @@ steps: displayName: 'Download Secondary CUDA SDK v${{ parameters.SecondaryCUDAVersion }}' - ${{ if eq(parameters.DownloadTRT, 'true') }}: - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda11 }}" $(Agent.TempDirectory) + displayName: 'Download ${{ parameters.win_trt_folder_cuda11 }}' - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6' 
+ azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda12 }}" $(Agent.TempDirectory) + displayName: 'Download ${{ parameters.win_trt_folder_cuda12 }}' - task: BatchScript@1 displayName: 'setup env' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index ef48244bbb299..a8a5f13b1f73f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -34,7 +34,7 @@ parameters: steps: - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true architecture: ${{parameters.BuildArch}} diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 34de1201fa049..e8f391a73fa7b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -102,14 +102,14 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' force32bit: ${{ parameters.isX86 }} # Our build machine doesn't have java x86 - ${{ if eq(parameters.buildArch, 'x64') }}: - task: JavaToolInstaller@0 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' @@ -216,7 +216,7 @@ jobs: - ${{ if eq(parameters.EnablePython, true) }}: - powershell: | - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml -qq + python3 -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index 8972d55f6e190..7ac2e3a8addb6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -29,7 +29,6 @@ stages: - template: c-api-linux-cpu.yml parameters: AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - BaseImage: 'registry.access.redhat.com/ubi8/ubi' OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' @@ -40,7 +39,6 @@ stages: - template: c-api-linux-cpu.yml parameters: AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - BaseImage: 'arm64v8/almalinux:8' OnnxruntimeArch: 'aarch64' OnnxruntimeNodejsBindingArch: 'arm64' PoolName: 'onnxruntime-linux-ARM64-CPU-2019' diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 2ab432e94fcbd..41ba5c3868f5e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -73,7 +73,7 @@ jobs: displayName: 'Checkout submodules' - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' 
addToPath: true architecture: $(buildArch) - template: download-deps.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 080079388a76c..ab31e592d7d71 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -68,9 +68,6 @@ stages: jobs: - job: MacOS_C_API_Package_Publish pool: - ${{ if eq(parameters.DoESRP, true)}}: - vmImage: 'macOS-12' - ${{ else }}: vmImage: 'macOS-13' steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index 3b661d9eb2dc6..7a1addffee0e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -45,16 +45,17 @@ jobs: submodules: none - task: UsePythonVersion@0 - displayName: Use Python 3.11 + displayName: Use Python 3.10 inputs: - versionSpec: 3.11 + versionSpec: 3.10 + - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - task: JavaToolInstaller@0 inputs: - versionSpec: "11" + versionSpec: "17" jdkArchitectureOption: "x64" jdkSourceOption: 'PreInstalled' @@ -71,32 +72,19 @@ jobs: arguments: --new_dir $(Build.BinariesDirectory)/deps workingDirectory: $(Build.BinariesDirectory) - - template: mac-build-step-with-cache.yml - parameters: - WithCache: ${{ parameters.WithCache }} - Today: $(TODAY) - AdditionalKey: ' protobuf | "$(Agent.OS)" | $(Build.SourcesDirectory)/cmake/deps.txt, $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_protobuf.sh' - CacheDir: $(PROTO_CACHE_DIR) - ChangeEveryCommit: false - BuildStep: - - script: | - set -e -x - pushd . 
- $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_protobuf.sh -d $(Build.SourcesDirectory)/cmake/deps.txt -p $(Build.BinariesDirectory)/installed - popd - export PATH=$(Build.BinariesDirectory)/installed/bin:$PATH - export ONNX_ML=1 - export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" - python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - displayName: 'Install dependencies' - env: - CCACHE_DIR: $(PROTO_CACHE_DIR) + - script: | + set -e -x + export PATH=$(Build.BinariesDirectory)/installed/bin:$PATH + export ONNX_ML=1 + export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" + python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' + - ${{ if eq(parameters.MacosArch, 'universal2') }}: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" BuildJava: false BuildNodejs: false WithCache: ${{ parameters.WithCache }} @@ -108,7 +96,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} @@ -120,7 +108,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index 5cfa135135dca..90055cbbc6c3e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -233,7 +233,7 @@ stages: - task: UsePythonVersion@0 displayName: 'Use Python' inputs: - versionSpec: 3.8 + versionSpec: 3.12 - task: MSBuild@1 displayName: 'Build Nuget Packages' diff --git a/tools/ci_build/github/azure-pipelines/templates/orttraining-linux-gpu-test-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/orttraining-linux-gpu-test-ci-pipeline.yml deleted file mode 100644 index 5f073433265fa..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/orttraining-linux-gpu-test-ci-pipeline.yml +++ /dev/null @@ -1,41 +0,0 @@ -parameters: -- name: DockerImageTag - type: string -- name: BuildConfig - type: string - -steps: - -- template: jobs/download_training_test_data.yml - - # Entry point for all ORTModule tests - # The onnxruntime folder is deleted in the build directory - # to enforce use 
of the onnxruntime wheel - # Uninstall orttraining requirements.txt and install ortmodule requirements.txt before running tests. -- script: | - docker run \ - --gpus all \ - --shm-size=1024m \ - --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory)/${{ parameters.BuildConfig }}:/build \ - --volume $(Agent.TempDirectory)/mnist:/mnist \ - ${{ parameters.DockerImageTag }} \ - bash -c "rm -rf /build/onnxruntime/ && python3 -m pip show torch && python3 -m pip install torch==2.3.1+cu118 --index-url https://download.pytorch.org/whl/cu118 && python3 -m pip install /build/dist/onnxruntime*.whl && python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install && /build/launch_test.py --cmd_line_with_args 'python orttraining_ortmodule_tests.py --mnist /mnist --bert_data /bert_data/hf_data/glue_data/CoLA/original/raw' --cwd /build" \ - displayName: 'Run orttraining_ortmodule_tests.py' - condition: succeededOrFailed() - timeoutInMinutes: 60 - -# Entry point for all ort training api tests -- script: | - docker run \ - --gpus all \ - --shm-size=1024m \ - --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory)/${{ parameters.BuildConfig }}:/build \ - ${{ parameters.DockerImageTag }} \ - bash -c "rm -rf /build/onnxruntime/ && python3 -m pip install /build/dist/onnxruntime*.whl && python3 -m pip install torch==2.3.1+cu118 --index-url https://download.pytorch.org/whl/cu118 && /build/launch_test.py --cmd_line_with_args 'python orttraining_test_ort_apis.py --cwd /build' --cwd /build" \ - displayName: 'Run ORT Training APIs Tests' - condition: succeededOrFailed() - timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml b/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml index 8639a5ca0a55d..6e13db553629e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml @@ -34,7 +34,7 @@ stages: - task: UsePythonVersion@0 inputs: - versionSpec: '3.9' + versionSpec: '3.12' addToPath: true - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 6220a9a46c312..b1cec2284df65 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 jobs: - job: Linux_py_qnn_Wheels_x64 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index dd9d2412f8f91..c7becac763e28 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -26,8 +26,16 @@ parameters: type: string default: '' +- name: ep + type: string + default: 'cpu' + +- name: python_exe_path + type: string + default: '' + jobs: -- job: Linux_py_Wheels_${{ parameters.arch }} +- job: Linux_py_Wheels_${{ parameters.arch }}_${{parameters.ep}} timeoutInMinutes: 240 workspace: clean: all @@ -42,9 +50,15 @@ jobs: value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - name: extra_build_args ${{ if ne(parameters.extra_build_arg, '') }}: - value: -x ${{ 
parameters.extra_build_arg }} + value: '-x ${{ parameters.extra_build_arg }}' ${{ if eq(parameters.extra_build_arg, '') }}: value: '' + - name: python_exe_path + ${{ if ne(parameters.python_exe_path, '') }}: + value: '-p ${{ parameters.python_exe_path }}' + ${{ if eq(parameters.python_exe_path, '') }}: + value: '' + steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -78,7 +92,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) ${{ if eq(parameters.with_cache, 'true') }}: env: ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" @@ -87,14 +101,14 @@ jobs: displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime + ArtifactName: onnxruntime-${{ parameters.ep }} - task: PublishPipelineArtifact@0 displayName: 'Publish Test Binaries' inputs: - artifactName: 'drop-linux-cpu-${{ parameters.arch }}' + artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' - template: component-governance-component-detection-steps.yml parameters : - condition : 'succeeded' \ No newline at end of file + condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index ff54dc647447c..10a0354979ae3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -28,6 +28,8 @@ jobs: PythonVersion: '3.11' Python312: PythonVersion: '3.12' + Python313: + PythonVersion: '3.13' steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml index 0c7c356393b54..bfa6b0d32cab5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -28,6 +28,10 @@ parameters: type: number default: 120 +- name: ep + type: string + default: 'cpu' + jobs: - job: Linux_Test_CPU${{ parameters.extra_job_id }}_${{ parameters.arch }} timeoutInMinutes: ${{ parameters.timeout }} @@ -43,30 +47,30 @@ jobs: # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - download: current # pipeline resource identifier. - artifact: 'drop-linux-cpu-${{ parameters.arch }}' + artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: current # pipeline resource identifier. 
- artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' - bash: | set -e -x - mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. - artifact: 'drop-linux-cpu-${{ parameters.arch }}' + artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: build # pipeline resource identifier. - artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' - bash: | set -e -x ls $(Pipeline.Workspace)/build - mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 0c3cd60a712fb..0473fc199a991 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -16,12 +16,6 @@ parameters: - name: docker_base_image type: string -- name: trt_version - type: string - default: '10.4.0.26-1.cuda11.8' - values: - - 10.4.0.26-1.cuda11.8 - - 10.4.0.26-1.cuda12.6 - name: cuda_version type: string default: '11.8' @@ -47,7 +41,14 @@ jobs: - job: Linux_Test_GPU${{ parameters.extra_job_id }}_${{ parameters.arch }} timeoutInMinutes: ${{ parameters.timeout }} variables: - skipComponentGovernanceDetection: true + - template: common-variables.yml + - name: skipComponentGovernanceDetection + value: true + - name: trt_version + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: ${{ variables.linux_trt_version_cuda12 }} workspace: clean: all pool: ${{ parameters.machine_pool }} @@ -61,7 +62,7 @@ jobs: # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. 
- artifact: 'drop-linux-gpu-${{ parameters.arch }}' + artifact: 'linux_gpu_wheel_${{ parameters.arch }}' - download: build # pipeline resource identifier. artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' @@ -69,7 +70,7 @@ jobs: - bash: | set -e -x ls $(Pipeline.Workspace)/build - mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/linux_gpu_wheel_${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; @@ -92,7 +93,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda/Dockerfile Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda - DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ parameters.trt_version }} --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ variables.trt_version }} --build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} - task: Bash@3 @@ -100,7 +101,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockertest.sh - arguments: -d GPU -c ${{parameters.cmake_build_type}} -i onnxruntimecuda118xtrt86build${{ parameters.arch }} + arguments: -d GPU -c ${{parameters.cmake_build_type}} -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml index 8a6434e757a3c..7f3a61997b2f8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml @@ -241,7 +241,7 @@ stages: **/*.dll - powershell: | - python -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu -qq + python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*.whl | foreach {pip --disable-pip-version-check install --force-reinstall --upgrade $_.fullname tabulate} python -m pip install protobuf==3.18.1 Remove-Item -Recurse -Force onnxruntime @@ -334,7 +334,7 @@ stages: rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 sudo rm -f /build /onnxruntime_src sudo ln -s $(Build.SourcesDirectory) /onnxruntime_src - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml -qq + python3 -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt # Test ORT with the latest ONNX release. 
sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt @@ -379,9 +379,10 @@ stages: pool: 'onnxruntime-Win2022-GPU-A10' timeoutInMinutes: 300 variables: + - template: common-variables.yml CUDA_VERSION: '11.8' buildArch: x64 - EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" + EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda11 }}" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" EnvSetupScript: setup_env_gpu.bat EP_NAME: gpu VSGenerator: 'Visual Studio 17 2022' @@ -497,7 +498,7 @@ stages: **/*.dll - powershell: | - python -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu -qq + python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} Remove-Item -Recurse -Force onnxruntime python onnx_backend_test_series.py diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml deleted file mode 100644 index fc163d17e44a9..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml +++ /dev/null @@ -1,209 +0,0 @@ -parameters: - build_py_parameters: '' - torch_version: '' - opset_version: '' - cuda_version: '' - cmake_cuda_architectures: '' - docker_file: '' - upload_wheel: '' - debug_build: '' - python_version: '' - stage_name: '' - SpecificArtifact: false - BuildId: '0' - build_pool_name: '' - -stages: - - stage: Build_${{ parameters.stage_name }} - variables: - - name: isMain - value: ${{ or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')) }} - - name: finalStorage - ${{ if eq(variables['isMain'], 'true') }}: - value: '--final_storage' - ${{ else }}: - value: '' - - name: buildConfig - ${{ if eq(parameters['debug_build'], 'true') }}: - value: 'Debug' - ${{ else }}: - value: 'Release' - - name: PythonVersion - value: ${{ parameters.python_version }} - - name: Repository - value: onnxruntimetraininggpubuild_cu${{ replace(parameters.cuda_version, '.', '') }}_py${{ replace(parameters.python_version, '.', '') }} - dependsOn: [] - - jobs: - - job: Build - pool: ${{ parameters.build_pool_name }} - timeoutInMinutes: 180 - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - task: CmdLine@2 - displayName: 'check variables' - inputs: - script: | - echo "Branch is "${{ variables['Build.SourceBranch'] }} && \ - echo "isMain is "${{ variables['isMain'] }} && \ - echo "final_storage is "${{ variables['finalStorage'] }} - - - checkout: self - clean: true - submodules: recursive - - - template: set-python-manylinux-variables-step.yml - - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/${{ parameters.docker_file }} - Context: tools/ci_build/github/linux/docker - 
DockerBuildArgs: >- - --build-arg TORCH_VERSION=${{ parameters.torch_version }} - --build-arg OPSET_VERSION=${{ parameters.opset_version }} - --build-arg PYTHON_VERSION=${{ parameters.python_version }} - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu - --build-arg BUILD_UID=$(id -u) - Repository: $(Repository) - - - task: CmdLine@2 - displayName: 'build onnxruntime' - inputs: - script: | - set -e -x - mkdir -p $HOME/.onnx - docker run --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - -e NIGHTLY_BUILD \ - -e DEFAULT_TRAINING_PACKAGE_DEVICE \ - -e BUILD_BUILDNUMBER \ - -e ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION \ - $(Repository) \ - $(PythonManylinuxDir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build \ - --config ${{ variables['buildConfig'] }} \ - --skip_submodule_sync \ - --parallel --use_binskim_compliant_compile_flags \ - --build_wheel \ - --enable_onnx_tests \ - ${{ parameters.build_py_parameters }} \ - --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=${{ parameters.cmake_cuda_architectures }}' onnxruntime_BUILD_UNIT_TESTS=OFF \ - --use_cuda --cuda_version=${{ parameters.cuda_version }} --cuda_home=/usr/local/cuda-${{ parameters.cuda_version }} --cudnn_home=/usr/local/cuda-${{ parameters.cuda_version }}; - workingDirectory: $(Build.SourcesDirectory) - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)' - Contents: "${{ variables['buildConfig'] }}/dist/*.whl" - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel and documentation' - inputs: - ArtifactName: "onnxruntime_gpu_${{ variables['buildConfig'] }}_${{ parameters.python_version }}" - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - - template: clean-agent-build-directory-step.yml - - - stage: Test_${{ parameters.stage_name }} - variables: - - name: isMain - value: ${{ or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')) }} - - name: finalStorage - ${{ if eq(variables['isMain'], 'true') }}: - value: '--final_storage' - ${{ else }}: - value: '' - - name: buildConfig - ${{ if eq(parameters['debug_build'], 'true') }}: - value: 'Debug' - ${{ else }}: - value: 'Release' - - name: PythonVersion - value: ${{ parameters.python_version }} - - name: Repository - value: onnxruntimetraininggpubuild_cu${{ replace(parameters.cuda_version, '.', '') }}_py${{ replace(parameters.python_version, '.', '') }} - - name: UploadWheel - value: ${{ parameters.upload_wheel }} - dependsOn: Build_${{ parameters.stage_name }} - jobs: - - job: Test_GPU - pool: Onnxruntime-Linux-GPU - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: none - - - template: jobs/download_training_test_data.yml - - - template: set-python-manylinux-variables-step.yml - - - template: flex-downloadPipelineArtifact.yml - parameters: - ArtifactName: "onnxruntime_gpu_${{ variables['buildConfig'] }}_${{ parameters.python_version }}" - StepName: 'Download Pipeline Artifact - Linux Training Build' - TargetPath: 
'$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - script: | - set -e -x - whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \ - echo $whlfilename ; du -sh $whlfilename ; \ - (( $(wc -c < "$whlfilename") - 400*1024*1024 < 0 )) || ( echo 'Wheel size bigger than 400M'; exit 1) - displayName: 'Check wheel size' - continueOnError: true - - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/${{ parameters.docker_file }} - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg TORCH_VERSION=${{ parameters.torch_version }} - --build-arg OPSET_VERSION=${{ parameters.opset_version }} - --build-arg PYTHON_VERSION=${{ parameters.python_version }} - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu - --build-arg BUILD_UID=$(id -u) - Repository: $(Repository) - - - task: CmdLine@2 - displayName: 'test ortmodule' - inputs: - script: | - set -ex ; \ - whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \ - echo $whlfilename ; \ - basefilename=$(basename $whlfilename) ; \ - docker run --rm \ - --gpus all \ - -e NVIDIA_VISIBLE_DEVICES=all \ - --volume $(Build.ArtifactStagingDirectory):/build \ - --volume $(Agent.TempDirectory)/MNIST:/mnist \ - $(Repository) \ - bash -c " $(PythonManylinuxDir)/bin/python3 -m pip install /build/Release/dist/$basefilename && $(PythonManylinuxDir)/bin/python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install " ; - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 'Upload wheel' - condition: and(succeeded(), and(eq(variables['UploadWheel'], 'yes'), ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true'))) - inputs: - script: | - set -e -x - whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \ - python3 tools/ci_build/upload_python_package_to_azure_storage.py \ - --python_wheel_path $whlfilename ${{ variables['finalStorage'] }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 6e573d79e4a72..e07f0afa6109c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string @@ -59,6 +59,11 @@ jobs: addToPath: true architecture: 'arm64' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: onebranch.pipeline.tsaoptions@1 displayName: 'OneBranch TSAOptions' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 2c9218a059e0c..8cc647c2464f3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string @@ -50,6 +50,11 @@ jobs: addToPath: true architecture: 'x64' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: 
onebranch.pipeline.tsaoptions@1 displayName: 'OneBranch TSAOptions' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 9cb82d65bcdce..466fee92d0d5e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string @@ -50,6 +50,11 @@ jobs: addToPath: true architecture: 'x64' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: onebranch.pipeline.tsaoptions@1 displayName: 'OneBranch TSAOptions' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 6fed0192d866d..aa0b6bf6d391e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,6 +1,6 @@ parameters: - QnnSdk: '2.26.0.240828' - build_config: 'RelWithDebInfo' + QnnSdk: '2.28.2.241116' + build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' @@ -32,9 +32,9 @@ stages: - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true - + - template: jobs/download_win_qnn_sdk.yml parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} @@ -44,7 +44,7 @@ stages: inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' - + - task: VSBuild@1 displayName: 'Build onnxruntime' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 8593aa2d821fa..29c5f6bb34d7a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -23,7 +23,7 @@ parameters: displayName: 'Stage that the initial stage of react-native-ci depends on' type: string default: '' - + - name: enable_code_sign displayName: Use GPG to sign the jars type: boolean @@ -58,9 +58,9 @@ stages: steps: - template: use-xcode-version.yml - task: UsePythonVersion@0 - displayName: Use python 3.9 + displayName: Use python 3.12 inputs: - versionSpec: "3.9" + versionSpec: "3.12" addToPath: true architecture: "x64" @@ -113,9 +113,9 @@ stages: condition: always() - template: use-xcode-version.yml - task: UsePythonVersion@0 - displayName: Use python 3.9 + displayName: Use python 3.12 inputs: - versionSpec: "3.9" + versionSpec: "3.12" addToPath: true architecture: "x64" @@ -128,7 +128,7 @@ stages: - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - script: brew install coreutils ninja npm yarn @@ -261,8 +261,6 @@ stages: publishJUnitResults: true testResultsFiles: '**/TEST-*.xml' testRunTitle: 'React Native Android Instrumented Test results' - javaHomeOption: 'path' - jdkDirectory: '$(JAVA_HOME_11_X64)' sonarQubeRunAnalysis: false spotBugsAnalysis: false displayName: Run React Native Android Instrumented Tests diff --git a/tools/ci_build/github/azure-pipelines/templates/set-python-manylinux-variables-step.yml b/tools/ci_build/github/azure-pipelines/templates/set-python-manylinux-variables-step.yml 
index 68836117db81d..585a78c17e042 100644 --- a/tools/ci_build/github/azure-pipelines/templates/set-python-manylinux-variables-step.yml +++ b/tools/ci_build/github/azure-pipelines/templates/set-python-manylinux-variables-step.yml @@ -35,6 +35,10 @@ steps: variables = { "PythonManylinuxDir": "/opt/python/cp312-cp312" } + elif version == "3.13": + variables = { + "PythonManylinuxDir": "/opt/python/cp313-cp313" + } else: raise ValueError("Unsupported Python version: '{}'".format(version)) diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index e27de27036130..5d7ea5e7b2727 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -18,7 +18,12 @@ stages: vmImage: "macOS-13" variables: + # Note: Keep the Xcode version and iOS simulator version compatible. + # Check the table here to see what iOS simulator versions are supported by a particular Xcode version: + # https://developer.apple.com/support/xcode/ xcodeVersion: "14.3.1" + iosSimulatorRuntimeVersion: "16.4" + ortPodVersion: $[stageDependencies.IosPackaging_SetCommonVariables.j.outputs['SetCommonVariables.ORT_POD_VERSION']] ${{ if eq(parameters.packageVariant, 'Full') }}: @@ -57,13 +62,13 @@ stages: - task: UsePythonVersion@0 inputs: - versionSpec: "3.9" + versionSpec: "3.12" addToPath: true architecture: "x64" - template: ../use-xcode-version.yml - - - template: ../install-appcenter.yml + parameters: + xcodeVersion: $(xcodeVersion) - script: | pip install -r tools/ci_build/github/apple/ios_packaging/requirements.txt @@ -80,6 +85,8 @@ stages: --build-settings-file "${{ variables.buildSettingsFile }}" \ ${{ variables.optionalIncludeOpsByConfigOption }} displayName: "Build macOS/iOS framework and assemble pod package files" + env: + ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: $(iosSimulatorRuntimeVersion) - script: | python tools/ci_build/github/apple/test_apple_packages.py \ @@ -91,6 +98,8 @@ stages: --prepare_test_project_only displayName: "Assemble test project for App Center" + # Xcode tasks require absolute paths because it searches for the paths and files relative to + # the root directory and not relative to the working directory - task: Xcode@5 inputs: actions: 'build-for-testing' @@ -98,8 +107,6 @@ stages: xcWorkspacePath: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/apple_package_test.xcworkspace' sdk: 'iphoneos' scheme: 'ios_package_test' - xcodeVersion: 'specifyPath' - xcodeDeveloperDir: '/Applications/Xcode_${{ variables.xcodeVersion }}.app/Contents/Developer' signingOption: 'manual' signingIdentity: '$(APPLE_CERTIFICATE_SIGNING_IDENTITY)' provisioningProfileUuid: '$(APPLE_PROV_PROFILE_UUID)' @@ -108,16 +115,69 @@ stages: useXcpretty: false # xcpretty can hide useful error output so we will disable it displayName: 'Build App Center iPhone arm64 tests' + - script: | + zip -r --symlinks $(Build.ArtifactStagingDirectory)/package_tests.zip ios_package_testUITests-Runner.app + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData/Build/Products/Debug-iphoneos' + displayName: "Create .zip file of the tests" + + - script: | + python $(Build.SourcesDirectory)/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py \ + --dest_file "exportOptions.plist" \ + --apple_team_id 
$(APPLE_TEAM_ID) \ + --provisioning_profile_uuid $(APPLE_PROV_PROFILE_UUID) + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + displayName: "Generate .plist file for the .ipa file" + + # Task only generates an .xcarchive file if the plist export options are included, but does + # not produce an IPA file. + # Source code: https://github.com/microsoft/azure-pipelines-tasks/blob/master/Tasks/XcodeV5/xcode.ts + - task: Xcode@5 + inputs: + actions: 'archive' + xcWorkspacePath: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/apple_package_test.xcworkspace' + packageApp: true + archivePath: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + exportOptions: 'plist' + exportOptionsPlist: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/exportOptions.plist' + configuration: 'Debug' + sdk: 'iphoneos' + scheme: 'ios_package_test' + args: '-derivedDataPath $(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData' + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + useXcpretty: false + displayName: 'Create archive for the .ipa file' + + # Use script step because exporting the .ipa file using the Xcode@5 task was too brittle (Xcode@5 is designed + # to handle both the .xcarchive step and the .ipa step in the same step -- ran into countless issues with signing + # and the .plist file) + - script: | + xcodebuild -exportArchive \ + -archivePath ios_package_test.xcarchive \ + -exportOptionsPlist exportOptions.plist \ + -exportPath $(Build.ArtifactStagingDirectory)/test_ipa + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + displayName: "Create .ipa file" + + # Publish the BrowserStack artifacts first so that if the next step fails, the artifacts will still be published + # so that users can attempt to locally debug + - publish: "$(Build.ArtifactStagingDirectory)" + artifact: "browserstack_test_artifacts_${{ lower(parameters.packageVariant) }}" + displayName: "Publish BrowserStack test artifacts" + - script: | set -e -x - appcenter test run xcuitest \ - --app "AI-Frameworks/ORT-Mobile-iOS" \ - --devices $(app_center_test_devices) \ - --test-series "master" \ - --locale "en_US" \ - --build-dir $(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData/Build/Products/Debug-iphoneos \ - --token $(app_center_api_token) - displayName: "Run E2E tests on App Center" + pip install requests + python $(Build.SourcesDirectory)/tools/python/upload_and_run_browserstack_tests.py \ + --test_platform xcuitest \ + --app_path "$(Build.ArtifactStagingDirectory)/test_ipa/ios_package_test.ipa" \ + --test_path "$(Build.ArtifactStagingDirectory)/package_tests.zip" \ + --devices "iPhone 15-17" + displayName: Run E2E tests using Browserstack + workingDirectory: $(Build.BinariesDirectory)/app_center_test/apple_package_test + timeoutInMinutes: 15 + env: + BROWSERSTACK_ID: $(browserstack_username) + BROWSERSTACK_TOKEN: $(browserstack_access_key) - script: | set -e -x diff --git a/tools/ci_build/github/azure-pipelines/templates/validate-package.yml b/tools/ci_build/github/azure-pipelines/templates/validate-package.yml index 5014b315a4083..529cca4586ef6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/validate-package.yml +++ b/tools/ci_build/github/azure-pipelines/templates/validate-package.yml @@ -11,11 +11,11 @@ steps: - task: UsePythonVersion@0 displayName: 'Use Python' inputs: - versionSpec: 3.8 + versionSpec: 3.12 - task: PythonScript@0 
displayName: 'Validate Package' inputs: scriptPath: '${{parameters.ScriptPath}}' arguments: '--package_type ${{parameters.PackageType}} --package_name ${{parameters.PackageName}} --package_path ${{parameters.PackagePath}} --platforms_supported ${{parameters.PlatformsSupported}} --verify_nuget_signing ${{parameters.VerifyNugetSigning}}' - workingDirectory: ${{parameters.workingDirectory}} + workingDirectory: ${{parameters.workingDirectory}} diff --git a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml index b9a184c8d9bcf..8fd532e73c114 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml @@ -29,7 +29,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index b2e1833156657..01a3b75ed958c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -73,7 +73,7 @@ stages: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 52547fd9a796b..59950433b3d40 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -118,25 +118,33 @@ stages: clean: true submodules: none + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: ${{ parameters.buildArch }} + - template: telemetry-steps.yml + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - ${{ if eq(parameters['buildJava'], 'true') }}: - task: JavaToolInstaller@0 inputs: - versionSpec: "11" + versionSpec: "17" jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.8' - addToPath: true - architecture: ${{ parameters.buildArch }} - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: - versionSpec: '18.x' + versionSpec: '20.x' - ${{ if ne(parameters.CudaVersion, '') }}: - template: jobs/download_win_gpu_library.yml @@ -354,20 +362,27 @@ stages: - ${{ if eq(parameters['buildJava'], 'true') }}: - task: JavaToolInstaller@0 inputs: - versionSpec: "11" + versionSpec: "17" jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true architecture: ${{ parameters.buildArch }} + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: - versionSpec: '18.x' + versionSpec: 
'20.x' - ${{ if ne(parameters.CudaVersion, '') }}: - template: jobs/download_win_gpu_library.yml @@ -397,4 +412,4 @@ stages: parameters: msbuildPlatform: ${{ parameters.msbuildPlatform }} java_artifact_id: ${{ parameters.java_artifact_id }} - buildOnly: false \ No newline at end of file + buildOnly: false diff --git a/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml b/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml index 8a386963a89dd..86acebc9f7a71 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml @@ -67,6 +67,7 @@ steps: - task: PowerShell@2 displayName: 'Signature validation for signed file(s)' + condition: and(succeeded(), eq('${{ parameters.DoEsrp }}', true)) inputs: targetType: 'inline' script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 64e7b6dbb4455..2b9d2b77f1e6b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -76,12 +76,12 @@ jobs: displayName: 'Checkout submodules' - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 0e8a7eb94379b..8aa73386b8d7d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -74,7 +74,7 @@ jobs: displayName: 'Testing: force EOL to lf on windows for /js/**' - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index 436d914c426ad..6b6e4a869a0d2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -12,6 +12,9 @@ jobs: workspace: clean: all steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - checkout: self submodules: false - task: DownloadPipelineArtifact@2 @@ -34,7 +37,7 @@ jobs: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' @@ -60,6 +63,14 @@ jobs: npm ci workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'npm ci /js/web/' + - task: Cache@2 + inputs: + key: onnxtestdata | $(Build.SourcesDirectory)\js\scripts\prepare-onnx-node-tests.ts + restoreKeys: | + onnxtestdata | $(Build.SourcesDirectory)\js\scripts\prepare-onnx-node-tests.ts + path: $(Build.SourcesDirectory)/js/test/ + cacheHitVar: CACHE_RESTORED + displayName: 'Cache ONNX node test data' - script: | powershell "Get-WmiObject Win32_Process -Filter \"name = 'chrome.exe'\" | Format-List CommandLine" displayName: 'Check active Chrome processes (before test)' @@ -87,6 +98,3 @@ jobs: npm test -- suite0 -b=wasm,webgl 
-e=edge --wasm.initTimeout=30000 --file-cache --user-data-dir=$(Agent.TempDirectory)\web\test_multi_browsers\03 workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'npm test (Suite0, Edge)' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml new file mode 100644 index 0000000000000..fb3ebdc760a7b --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml @@ -0,0 +1,183 @@ +parameters: +- name: BuildArch + displayName: BuildArch + type: string + default: 'x64' + +- name: Runtime + displayName: MSVC Runtime, should be 'dynamic' or 'static'. + type: string + default: 'dynamic' + +jobs: +- job: Windows_Packaging_${{ parameters.BuildArch }}_${{ parameters.Runtime }} + timeoutInMinutes: 180 + templateContext: + outputs: + - output: pipelineArtifact + path: '$(Build.ArtifactStagingDirectory)' + artifact: drop_Windows_Build_Windows_Packaging_${{ parameters.BuildArch }}_${{ parameters.Runtime }} + + steps: + - task: UseDotNet@2 + inputs: + version: '6.x' + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + ${{ if eq(parameters.BuildArch, 'x86') }}: + architecture: 'x86' + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - template: telemetry-steps.yml + + - task: NuGetCommand@2 + displayName: 'NuGet restore' + inputs: + command: restore + feedsToUse: config + nugetConfigPath: $(Build.SourcesDirectory)\tools\ci_build\github\azure-pipelines\nuget\nuget_config\nuget.config + restoreDirectory: '$(Build.BinariesDirectory)' + ${{ if eq(parameters.BuildArch, 'x64') }}: + restoreSolution: $(Build.SourcesDirectory)\tools\ci_build\github\azure-pipelines\nuget\nuget_config\x64\packages.config + ${{ if eq(parameters.BuildArch, 'x86') }}: + restoreSolution: $(Build.SourcesDirectory)\.tools\ci_build\github\azure-pipelines\nuget\nuget_config\x86\packages.config + ${{ if eq(parameters.BuildArch, 'arm') }}: + restoreSolution: $(Build.SourcesDirectory)\tools\ci_build\github\azure-pipelines\nuget\nuget_config\x64\packages.config + ${{ if eq(parameters.BuildArch, 'arm64') }}: + restoreSolution: $(Build.SourcesDirectory)\tools\ci_build\github\azure-pipelines\nuget\nuget_config\x64\packages.config + + - script: | + @echo off + set vswherepath="%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" + for /f "usebackq delims=" %%i in (`%vswherepath% -latest -property installationPath`) do ( + set vslatest="%%i" + if exist "%%i\Common7\Tools\vsdevcmd.bat" ( + set vsdevcmd="%%i\Common7\Tools\vsdevcmd.bat" + ) + ) + + @echo vslatest %vslatest% + @echo vsdevcmd %vsdevcmd% + + @echo ##vso[task.setvariable variable=vslatest]%vslatest% + @echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd% -arch=${{ parameters.BuildArch }} + displayName: 'locate vsdevcmd via vswhere' + + - powershell: | + Write-Host "##vso[task.setvariable variable=BuildFlags]" + Write-Host "##vso[task.setvariable variable=ArtifactName]Microsoft.AI.MachineLearning.${{ parameters.BuildArch }}" + displayName: Initialize build flags + + - powershell: | + Write-Host "##vso[task.setvariable variable=BuildFlags]$(BuildFlags) --${{ parameters.BuildArch }}" + displayName: Add cross compilation flags for ARM + condition: and(ne('${{ parameters.BuildArch }}', 'x64'), ne('${{ parameters.BuildArch 
}}', 'x86')) + + - powershell: | + Write-Host "##vso[task.setvariable variable=BuildFlags]$(BuildFlags) --enable_msvc_static_runtime" + Write-Host "##vso[task.setvariable variable=ArtifactName]$(ArtifactName).StaticRuntime" + displayName: Add static runtime flags + condition: eq('${{ parameters.Runtime }}', 'static') + + # must call vsdevcmd first to add cmake to PATH + - script: | + python --version + python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Generate cmake config' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + ${{ if ne(parameters.BuildArch, 'x86') }}: + platform: ${{ parameters.BuildArch }} + ${{ if eq(parameters.BuildArch, 'x86') }}: + platform: 'Win32' + configuration: RelWithDebInfo + msbuildArchitecture: ${{ parameters.BuildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + - ${{ if eq(parameters.Runtime, 'dynamic') }}: + - script: | + xcopy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\winml_test_api.exe $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\winml_test_scenario.exe $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\winml\test\api\models\*.onnx $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\winml\test\scenario\cppwinrt\*.onnx $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\winml\test\scenario\models\*.onnx $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\winml\test\common\testdata\squeezenet\* $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\winml\test\collateral\models\*.onnx $(Build.ArtifactStagingDirectory)\test_artifact\ + xcopy $(Build.SourcesDirectory)\winml\test\collateral\models\ModelSubdirectory $(Build.ArtifactStagingDirectory)\test_artifact\ModelSubdirectory\ /i + copy $(Build.SourcesDirectory)\winml\test\collateral\images\*.png $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\winml\test\collateral\images\*.jpg $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\onnxruntime\test\testdata\sequence_length.onnx $(Build.ArtifactStagingDirectory)\test_artifact\ + copy $(Build.SourcesDirectory)\onnxruntime\test\testdata\sequence_construct.onnx $(Build.ArtifactStagingDirectory)\test_artifact\ + displayName: 'Copy WinML test collateral to artifact directory' + + + - ${{ if eq(parameters.BuildArch, 'x64') }}: + - script: | + call $(vsdevcmd) + msbuild Microsoft.AI.MachineLearning.Interop.csproj /p:Configuration=RelWithDebInfo /p:Platform="Any CPU" /p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory) -restore + workingDirectory: '$(Build.SourcesDirectory)\csharp\src\Microsoft.AI.MachineLearning.Interop' + displayName: 'Build 
Microsoft.AI.MachineLearning.Interop.dll' + + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + DisplayName: 'Sign runtime DLLs' + Pattern: '*.exe,*.dll' + + - ${{ if eq(parameters.BuildArch, 'x64') }}: + - script: | + call $(vsdevcmd) + msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreateWindowsAIPackage /p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory) /p:OnnxRuntimeSourceDirectory=$(Build.SourcesDirectory) + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.snupkg $(Build.ArtifactStagingDirectory) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + displayName: 'Create NuGet Package' + + - ${{ if eq(parameters.BuildArch, 'x86') }}: + - script: | + call $(vsdevcmd) + msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreateWindowsAIPackage /p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory) /p:OnnxRuntimeSourceDirectory=$(Build.SourcesDirectory) /p:TargetArchitecture=x86 + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.snupkg $(Build.ArtifactStagingDirectory) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + displayName: 'Create NuGet Package' + + - ${{ if eq(parameters.BuildArch, 'arm64') }}: + - script: | + call $(vsdevcmd) + msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreateWindowsAIPackage /p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory) /p:OnnxRuntimeSourceDirectory=$(Build.SourcesDirectory) /p:TargetArchitecture=arm64 /p:ProtocDirectory=$(Build.BinariesDirectory)\host_protoc\Release + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.snupkg $(Build.ArtifactStagingDirectory) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + displayName: 'Create NuGet Package' + + - ${{ if eq(parameters.BuildArch, 'arm') }}: + - script: | + call $(vsdevcmd) + msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreateWindowsAIPackage /p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory) /p:OnnxRuntimeSourceDirectory=$(Build.SourcesDirectory) /p:TargetArchitecture=arm /p:ProtocDirectory=$(Build.BinariesDirectory)\host_protoc\Release + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) + copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.snupkg $(Build.ArtifactStagingDirectory) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + displayName: 'Create NuGet Package' + + # Only dynamic copied to test_artifact + - ${{ if eq(parameters.Runtime, 'dynamic') }}: + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.ArtifactStagingDirectory)\test_artifact' + DisplayName: 'Sign test_artifact' + Pattern: '*.exe,*.dll' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index b5120f01bff3e..0ab6b08662308 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -36,7 +36,7 @@ jobs: - task: 
NodeTool@0 inputs: - versionSpec: '18.x' + versionSpec: '20.x' - task: NuGetToolInstaller@0 displayName: Use Nuget 6.10.x diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index 7c04d6aa2e739..f4ab9ee5b4a5c 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -37,11 +37,12 @@ parameters: - 12.2 variables: + - template: templates/common-variables.yml - name: win_trt_folder ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 + value: ${{ variables.win_trt_folder_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 + value: ${{ variables.win_trt_folder_cuda12 }} jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml index c4db7735aaf2f..06f374afca57a 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml @@ -41,10 +41,11 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat + EnvSetupScript: setup_env.bat buildArch: x64 - # add --enable_pybind and --build_java if necessary + # add --build_java if necessary additionalBuildFlags: >- + --enable_pybind --build_nodejs --use_webgpu --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON @@ -56,3 +57,52 @@ stages: EnablePython: false WITH_CACHE: true MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 + +- stage: webgpu_external_dawn + dependsOn: [] + jobs: + - job: build_x64_RelWithDebInfo + variables: + DEPS_CACHE_DIR: $(Agent.TempDirectory)/deps_ccache + ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all + pool: onnxruntime-Win2022-VS2022-webgpu-A10 + timeoutInMinutes: 300 + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/jobs/win-ci-prebuild-steps.yml + parameters: + EnvSetupScript: setup_env.bat + DownloadCUDA: false + DownloadTRT: false + BuildArch: x64 + BuildConfig: RelWithDebInfo + MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 + WithCache: true + Today: $(Today) + + - template: templates/jobs/win-ci-build-steps.yml + parameters: + WithCache: true + Today: $(TODAY) + CacheDir: $(ORT_CACHE_DIR) + AdditionalKey: " $(System.StageName) | RelWithDebInfo " + BuildPyArguments: '--config RelWithDebInfo --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --update --parallel --cmake_generator "Visual Studio 17 2022" --use_webgpu --use_external_dawn --skip_tests --target onnxruntime_webgpu_external_dawn_test' + MsbuildArguments: '-maxcpucount' + BuildArch: x64 + Platform: x64 + BuildConfig: RelWithDebInfo + + - script: | + onnxruntime_webgpu_external_dawn_test.exe + displayName: Run tests (onnxruntime_webgpu_external_dawn_test) + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + - script: | + onnxruntime_webgpu_external_dawn_test.exe --no_proc_table + displayName: Run tests (onnxruntime_webgpu_external_dawn_test) + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml 
b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 44f5235e70c9f..8b5a5ecc13a44 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index bb448e848e499..978b14b76541c 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.26.0.240828 + default: 2.28.2.241116 jobs: - job: 'build' @@ -54,7 +54,7 @@ jobs: - task: UsePythonVersion@0 inputs: - versionSpec: '3.8' + versionSpec: '3.12' addToPath: true architecture: $(buildArch) diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh index 57a3bedc1e8e4..9922fc396b3d5 100755 --- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh +++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -2,4 +2,4 @@ set -e -x docker run --rm --volume \ $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}build \ -/bin/bash -c "/usr/bin/python3.9 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' && cd /build/Release && make install DESTDIR=/build/installed" +/bin/bash -c "/usr/bin/python3.12 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/build_cuda_ci.sh b/tools/ci_build/github/linux/build_cuda_ci.sh index c8691b3a01e70..0533b7b394492 100755 --- a/tools/ci_build/github/linux/build_cuda_ci.sh +++ b/tools/ci_build/github/linux/build_cuda_ci.sh @@ -3,28 +3,31 @@ set -ex #Every cuda container has this $CUDA_VERSION env var set. 
SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') -BUILD_ARGS=('--config' 'Release' '--update' '--build' - '--skip_submodule_sync' - '--build_shared_lib' - '--parallel' '--use_binskim_compliant_compile_flags' - '--build_wheel' - '--enable_onnx_tests' - '--use_cuda' - "--cuda_version=$SHORT_CUDA_VERSION" - "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" - "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" - "--enable_cuda_profiling" - "--enable_cuda_nhwc_ops" - "--enable_pybind" - "--build_java" - "--cmake_extra_defines" - "CMAKE_CUDA_ARCHITECTURES=75" - "onnxruntime_BUILD_UNIT_TESTS=ON" - "onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON") +BUILD_ARGS=('--config' + 'Release' + '--update' + '--build' + '--skip_submodule_sync' + '--build_shared_lib' + '--parallel' + '--use_binskim_compliant_compile_flags' + '--build_wheel' + '--enable_onnx_tests' + '--use_cuda' + "--cuda_version=$SHORT_CUDA_VERSION" + "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" + "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" + "--enable_cuda_profiling" + "--enable_pybind" + "--build_java" + "--cmake_extra_defines" + "CMAKE_CUDA_ARCHITECTURES=75" + "onnxruntime_BUILD_UNIT_TESTS=ON" + "onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON") if [ -x "$(command -v ninja)" ]; then BUILD_ARGS+=('--cmake_generator' 'Ninja') fi - + if [ -d /build ]; then BUILD_ARGS+=('--build_dir' '/build') else @@ -33,14 +36,14 @@ fi if [ -x "$(command -v ccache)" ]; then ccache -s; - BUILD_ARGS+=("--use_cache") + #BUILD_ARGS+=("--use_cache") fi -if [ -f /opt/python/cp38-cp38/bin/python3 ]; then - /opt/python/cp38-cp38/bin/python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" +if [ -f /opt/python/cp312-cp312/bin/python3 ]; then + /opt/python/cp312-cp312/bin/python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" else python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" fi -if [ -x "$(command -v ccache)" ]; then - ccache -sv +if [ -x "$(command -v ccache)" ]; then + ccache -sv ccache -z fi diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh index 568d9a74d75d1..e2e0cea69efb5 100755 --- a/tools/ci_build/github/linux/build_linux_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -6,18 +6,33 @@ set -e -x mkdir -p /build/dist EXTRA_ARG="" - -# Put 3.10 at the last because Ubuntu 22.04 use python 3.10 and we will upload the intermediate build files of this +ENABLE_CACHE=false +# Put 3.10 at the last because Ubuntu 22.04 use python 3.10 and we will upload the intermediate build files of this # config to Azure DevOps Artifacts and download them to a Ubuntu 22.04 machine to run the tests. -PYTHON_EXES=("/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12" "/opt/python/cp310-cp310/bin/python3.10") -while getopts "d:p:x:c:" parameter_Option +PYTHON_EXES=( + "/opt/python/cp311-cp311/bin/python3.11" + "/opt/python/cp312-cp312/bin/python3.12" + "/opt/python/cp313-cp313/bin/python3.13" + "/opt/python/cp313-cp313t/bin/python3.13t" + "/opt/python/cp310-cp310/bin/python3.10" + ) +while getopts "d:p:x:c:e" parameter_Option do case "${parameter_Option}" in #GPU|CPU|NPU. d) BUILD_DEVICE=${OPTARG};; -p) PYTHON_EXES=${OPTARG};; +p) + # Check if OPTARG is empty or starts with a hyphen, indicating a missing or invalid argument for -p + if [[ -z "${OPTARG}" || "${OPTARG}" == -* ]]; then + echo "ERROR: Option -p requires a valid argument, not another option." 
+ exit 1 + else + PYTHON_EXES=("${OPTARG}") # Use the provided argument for -p + fi + ;; x) EXTRA_ARG=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; +e) ENABLE_CACHE=true;; *) echo "Usage: $0 -d [-p ] [-x ] [-c ]" exit 1;; esac @@ -26,17 +41,39 @@ done BUILD_ARGS=("--build_dir" "/build" "--config" "$BUILD_CONFIG" "--update" "--build" "--skip_submodule_sync" "--parallel" "--use_binskim_compliant_compile_flags" "--build_wheel") -if [[ "$EXTRA_ARG" == *"training"* ]]; then - echo "Skip building unit tests because the container is a manylinux docker" - BUILD_ARGS+=("--cmake_extra_defines" "onnxruntime_BUILD_UNIT_TESTS=OFF") -fi if [ "$BUILD_CONFIG" != "Debug" ]; then BUILD_ARGS+=("--enable_lto") fi +if [ "$ENABLE_CACHE" = true ] ; then + BUILD_ARGS+=("--use_cache") + # No release binary for ccache aarch64, so we need to build it from source. + if ! [ -x "$(command -v ccache)" ]; then + ccache_url="https://github.com/ccache/ccache/archive/refs/tags/v4.8.tar.gz" + cd /build + curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o ccache_src.tar.gz $ccache_url + mkdir ccache_main + cd ccache_main + tar -zxf ../ccache_src.tar.gz --strip=1 + + mkdir build + cd build + cmake -DCMAKE_INSTALL_PREFIX=/build -DCMAKE_BUILD_TYPE=Release .. + make -j$(nproc) + make install + export PATH=/build/bin:$PATH + which ccache + rm -f ccache_src.tar.gz + rm -rf ccache_src + fi + ccache -s; +fi ARCH=$(uname -m) + + + echo "EXTRA_ARG:" echo "$EXTRA_ARG" @@ -52,7 +89,7 @@ fi if [ "$BUILD_DEVICE" == "GPU" ]; then SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') #Enable CUDA and TRT EPs. - BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") + BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") fi if [ "$BUILD_DEVICE" == "NPU" ]; then @@ -60,12 +97,20 @@ if [ "$BUILD_DEVICE" == "NPU" ]; then BUILD_ARGS+=("--use_qnn" "--qnn_home=/qnn_sdk") fi +export ONNX_ML=1 +export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" + for PYTHON_EXE in "${PYTHON_EXES[@]}" do rm -rf /build/"$BUILD_CONFIG" - ${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}" - + # This is a workaround for there being no python3 on PATH in the docker image: + # e.g. XNNPACK's CMake file uses python3 to run an external command + python3_dir=$(dirname "$PYTHON_EXE") + ${PYTHON_EXE} -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt + PATH=$python3_dir:$PATH ${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}" cp /build/"$BUILD_CONFIG"/dist/*.whl /build/dist done -which ccache && ccache -sv && ccache -z +if [ "$ENABLE_CACHE" = true ] ; then + which ccache && ccache -sv && ccache -z +fi diff --git a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh index f0c9d51a53448..7f18e2f849d27 100755 --- a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh +++ b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh @@ -3,4 +3,4 @@ set -e -x mkdir -p $HOME/.onnx docker run --rm --volume
/data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ --volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ -/bin/bash -c "/usr/bin/python3.9 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' && cd /build/Release && make install DESTDIR=/build/installed" +/bin/bash -c "/usr/bin/python3.12 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/build_tensorrt_ci.sh b/tools/ci_build/github/linux/build_tensorrt_ci.sh index 3002f2c239f1a..5b206bc0a92d9 100755 --- a/tools/ci_build/github/linux/build_tensorrt_ci.sh +++ b/tools/ci_build/github/linux/build_tensorrt_ci.sh @@ -35,8 +35,8 @@ if [ -x "$(command -v ccache)" ]; then ccache -s; BUILD_ARGS+=("--use_cache") fi -if [ -f /opt/python/cp38-cp38/bin/python3 ]; then - /opt/python/cp38-cp38/bin/python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" +if [ -f /opt/python/cp312-cp312/bin/python3 ]; then + /opt/python/cp312-cp312/bin/python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" else python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" fi diff --git a/tools/ci_build/github/linux/docker/Dockerfile.aten_cpu b/tools/ci_build/github/linux/docker/Dockerfile.aten_cpu deleted file mode 100644 index 16cf0dfa4f777..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.aten_cpu +++ /dev/null @@ -1,10 +0,0 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241008.1 - -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps_aten.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/Dockerfile.lort_cpu b/tools/ci_build/github/linux/docker/Dockerfile.lort_cpu deleted file mode 100644 index 04b535e49548c..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.lort_cpu +++ /dev/null @@ -1,10 +0,0 @@ -FROM registry.access.redhat.com/ubi8/ubi - -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps_lort.sh && rm -rf /tmp/scripts -ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin -ARG BUILD_UID=1002 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index 
b517af75d2c91..d2d3aa1675c2e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,6 +1,6 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241008.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241120.3 -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 07885ba65af8a..c42042b0ec639 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -32,7 +32,7 @@ else \ echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ fi -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts @@ -42,5 +42,5 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH -ENV CUDA_MODULE_LOADING "LAZY" \ No newline at end of file +ENV PATH=/usr/local/dotnet:$PATH +ENV CUDA_MODULE_LOADING="LAZY" \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index e1914d5fe2f06..9a265b4249f0b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -6,7 +6,7 @@ ARG LD_LIBRARY_PATH_ARG=${DEVTOOLSET_ROOTPATH}/usr/lib64:${DEVTOOLSET_ROOTPATH}/ ARG PREPEND_PATH=${DEVTOOLSET_ROOTPATH}/usr/bin: FROM $BASEIMAGE AS base_image -ARG ROCM_VERSION=5.5 +ARG ROCM_VERSION=6.2.3 #Add our own dependencies ADD scripts /tmp/scripts @@ -185,8 +185,6 @@ ARG INSTALL_DEPS_EXTRA_ARGS RUN cd /tmp/scripts && \ /tmp/scripts/manylinux/install_centos.sh && \ /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS && \ - /tmp/scripts/install_python_deps.sh -d gpu -p 3.8 $INSTALL_DEPS_EXTRA_ARGS && \ - /tmp/scripts/install_python_deps.sh -d gpu -p 3.9 $INSTALL_DEPS_EXTRA_ARGS && \ /tmp/scripts/install_python_deps.sh -d gpu -p 3.10 $INSTALL_DEPS_EXTRA_ARGS && \ rm -rf /tmp/scripts @@ -203,5 +201,5 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH +ENV PATH=/usr/local/dotnet:$PATH ENV ORTMODULE_ONNX_OPSET_VERSION=$OPSET_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 deleted file mode 100644 index 9134109237930..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 +++ /dev/null @@ -1,26 +0,0 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241008.1 -ARG PYTHON_VERSION=3.9 -ARG TORCH_VERSION=2.0.0 -ARG OPSET_VERSION=17 -ARG INSTALL_DEPS_EXTRA_ARGS - -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && \ - 
/tmp/scripts/manylinux/install_centos.sh && \ - /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS && \ - /tmp/scripts/install_rust.sh - -ENV PATH="/root/.cargo/bin/:$PATH" - -RUN /tmp/scripts/install_ninja.sh && \ - /tmp/scripts/install_python_deps.sh -d gpu -v 11.8 -p $PYTHON_VERSION -h $TORCH_VERSION $INSTALL_DEPS_EXTRA_ARGS && \ - rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH -ENV ORTMODULE_ONNX_OPSET_VERSION=$OPSET_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 deleted file mode 100644 index 1bea0df1fc2cf..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 +++ /dev/null @@ -1,27 +0,0 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241008.1 - -ARG PYTHON_VERSION=3.9 -ARG TORCH_VERSION=2.1.0 -ARG OPSET_VERSION=17 -ARG INSTALL_DEPS_EXTRA_ARGS - -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && \ - /tmp/scripts/manylinux/install_centos.sh && \ - /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS && \ - /tmp/scripts/install_rust.sh - -ENV PATH="/root/.cargo/bin/:$PATH" - -RUN /tmp/scripts/install_ninja.sh && \ - /tmp/scripts/install_python_deps.sh -d gpu -v 12.2 -p $PYTHON_VERSION -h $TORCH_VERSION $INSTALL_DEPS_EXTRA_ARGS && \ - rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH -ENV ORTMODULE_ONNX_OPSET_VERSION=$OPSET_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 index 8ef8e05b8ac77..9de88d1664b82 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -9,19 +9,19 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG TRT_VERSION=8.6.1.6-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache # Install python3 RUN dnf install -y \ - python3.8 \ - python38-pip \ - python38-wheel &&\ + python3.10 \ + python310-pip \ + python310-wheel &&\ cd /usr/local/bin &&\ - ln -s /usr/bin/python3 python3.8 &&\ - ln -s /usr/bin/pip3 pip3.8; + ln -s /usr/bin/python3 python3.10 &&\ + ln -s /usr/bin/pip3 pip3.10; RUN pip3 install --upgrade pip RUN pip3 install setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 index c1a445e29fc89..c2bae5fd7ee59 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 @@ -6,10 +6,10 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 -ARG TRT_VERSION=10.4.0.26-1.cuda12.6 +ARG TRT_VERSION=10.6.0.26-1.cuda12.6 FROM $BASEIMAGE AS base ARG 
TRT_VERSION -ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/opt/python/cp310-cp310/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch index a228ebed165eb..2ecc6d1918b1a 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch @@ -6,10 +6,10 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG TRT_VERSION=10.4.0.26-1.cuda11.8 +ARG TRT_VERSION=10.6.0.26-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/opt/python/cp310-cp310/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index 6a4244b7aad0d..81aeada6a4a46 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -6,11 +6,11 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.4.0.26-1+cuda11.8 +ARG TRT_VERSION=10.6.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg similarity index 93% rename from tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg rename to tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg index 418c551ab38b4..4298dd53e4c66 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg @@ -5,12 +5,12 @@ # Dockerfile to run ONNXRuntime with TensorRT integration # Build base image with required system packages -ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.4.0.26-1+cuda11.8 +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 +ARG TRT_VERSION=10.6.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv new file mode 100644 index 
0000000000000..1312475ceca3a --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv @@ -0,0 +1,64 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with TensorRT integration + +# Build base image with required system packages +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 +ARG TRT_VERSION=10.6.0.26-1+cuda11.8 +ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 +FROM $BASEIMAGE AS base +ARG TRT_VERSION +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV DEBIAN_FRONTEND=noninteractive + +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} + +RUN apt-get update &&\ + apt-get install -y git bash wget diffutils + +RUN DEBIAN_FRONTEND="noninteractive" apt-get install --yes python3-opencv + +# Install python3 +RUN apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + python3-wheel + +RUN pip install --upgrade pip + +# Install TensorRT +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ + apt-get update &&\ + apt-get install -y \ + libnvinfer-dev=${TRT_VERSION} \ + libnvinfer-dispatch-dev=${TRT_VERSION} \ + libnvinfer-dispatch10=${TRT_VERSION} \ + libnvinfer-headers-dev=${TRT_VERSION} \ + libnvinfer-headers-plugin-dev=${TRT_VERSION} \ + libnvinfer-lean-dev=${TRT_VERSION} \ + libnvinfer-lean10=${TRT_VERSION} \ + libnvinfer-plugin-dev=${TRT_VERSION} \ + libnvinfer-plugin10=${TRT_VERSION} \ + libnvinfer-vc-plugin-dev=${TRT_VERSION} \ + libnvinfer-vc-plugin10=${TRT_VERSION} \ + libnvinfer10=${TRT_VERSION} \ + libnvonnxparsers-dev=${TRT_VERSION} \ + libnvonnxparsers10=${TRT_VERSION} \ + tensorrt-dev=${TRT_VERSION} \ + libnvinfer-bin=${TRT_VERSION} &&\ + if [ $(echo $CUDA_VERSION | cut -d"." -f1) -ge 12 ]; then apt-get install -y cudnn9-cuda-12 ; fi +# ^^^^^^^^^^^If cuda version is 12 or higher, install cudnn 9 for cuda 12 + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts + +# Build final image from base. 
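+# Hedged usage sketch (the tag "ort-trt-opencv" is an assumption for illustration, not defined in this file):
+#   docker build -f Dockerfile.package_ubuntu_2204_gpu_opencv --build-arg TRT_VERSION=10.6.0.26-1+cuda11.8 -t ort-trt-opencv .
+#   docker run --gpus all --rm ort-trt-opencv python3 -c "import cv2; print(cv2.__version__)"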
+FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index dfc057b129f91..3b4d36a9a8fd8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -82,7 +82,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 index a7d8f220ea9b3..22d5e3b0248a8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -31,26 +31,26 @@ RUN pip install --upgrade pip RUN pip install psutil setuptools>=68.2.2 # Install TensorRT -RUN version="10.4.0.26-1+cuda11.8" &&\ +RUN TRT_VERSION="10.6.0.26-1+cuda11.8" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ apt-get install -y \ - libnvinfer-dev=${version} \ - libnvinfer-dispatch-dev=${version} \ - libnvinfer-dispatch10=${version} \ - libnvinfer-headers-dev=${version} \ - libnvinfer-headers-plugin-dev=${version} \ - libnvinfer-lean-dev=${version} \ - libnvinfer-lean10=${version} \ - libnvinfer-plugin-dev=${version} \ - libnvinfer-plugin10=${version} \ - libnvinfer-vc-plugin-dev=${version} \ - libnvinfer-vc-plugin10=${version} \ - libnvinfer10=${version} \ - libnvonnxparsers-dev=${version} \ - libnvonnxparsers10=${version} \ - tensorrt-dev=${version} \ - libnvinfer-bin=${version} + libnvinfer-dev=${TRT_VERSION} \ + 
libnvinfer-dispatch-dev=${TRT_VERSION} \ + libnvinfer-dispatch10=${TRT_VERSION} \ + libnvinfer-headers-dev=${TRT_VERSION} \ + libnvinfer-headers-plugin-dev=${TRT_VERSION} \ + libnvinfer-lean-dev=${TRT_VERSION} \ + libnvinfer-lean10=${TRT_VERSION} \ + libnvinfer-plugin-dev=${TRT_VERSION} \ + libnvinfer-plugin10=${TRT_VERSION} \ + libnvinfer-vc-plugin-dev=${TRT_VERSION} \ + libnvinfer-vc-plugin10=${TRT_VERSION} \ + libnvinfer10=${TRT_VERSION} \ + libnvonnxparsers-dev=${TRT_VERSION} \ + libnvonnxparsers10=${TRT_VERSION} \ + tensorrt-dev=${TRT_VERSION} \ + libnvinfer-bin=${TRT_VERSION} # Compile trtexec if not installed RUN if [ ! -d /usr/src/tensorrt/bin ] || [ ! -f /usr/src/tensorrt/bin/trtexec ]; then \ @@ -98,7 +98,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 index f63112039fe8e..6d35df72894d8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -85,7 +85,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 index 523318f09aba6..819d9bab7be75 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV 
PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -31,26 +31,26 @@ RUN pip install --upgrade pip RUN pip install setuptools>=68.2.2 psutil # Install TensorRT -RUN version="10.4.0.26-1+cuda12.6" &&\ +RUN TRT_VERSION="10.6.0.26-1+cuda12.6" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ apt-get install -y \ - libnvinfer-dev=${version} \ - libnvinfer-dispatch-dev=${version} \ - libnvinfer-dispatch10=${version} \ - libnvinfer-headers-dev=${version} \ - libnvinfer-headers-plugin-dev=${version} \ - libnvinfer-lean-dev=${version} \ - libnvinfer-lean10=${version} \ - libnvinfer-plugin-dev=${version} \ - libnvinfer-plugin10=${version} \ - libnvinfer-vc-plugin-dev=${version} \ - libnvinfer-vc-plugin10=${version} \ - libnvinfer10=${version} \ - libnvonnxparsers-dev=${version} \ - libnvonnxparsers10=${version} \ - tensorrt-dev=${version} \ - libnvinfer-bin=${version} + libnvinfer-dev=${TRT_VERSION} \ + libnvinfer-dispatch-dev=${TRT_VERSION} \ + libnvinfer-dispatch10=${TRT_VERSION} \ + libnvinfer-headers-dev=${TRT_VERSION} \ + libnvinfer-headers-plugin-dev=${TRT_VERSION} \ + libnvinfer-lean-dev=${TRT_VERSION} \ + libnvinfer-lean10=${TRT_VERSION} \ + libnvinfer-plugin-dev=${TRT_VERSION} \ + libnvinfer-plugin10=${TRT_VERSION} \ + libnvinfer-vc-plugin-dev=${TRT_VERSION} \ + libnvinfer-vc-plugin10=${TRT_VERSION} \ + libnvinfer10=${TRT_VERSION} \ + libnvonnxparsers-dev=${TRT_VERSION} \ + libnvonnxparsers10=${TRT_VERSION} \ + tensorrt-dev=${TRT_VERSION} \ + libnvinfer-bin=${TRT_VERSION} # Compile trtexec if not installed RUN if [ ! -d /usr/src/tensorrt/bin ] || [ ! -f /usr/src/tensorrt/bin/trtexec ]; then \ @@ -98,7 +98,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training deleted file mode 100644 index 4d11cbbde3354..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training +++ /dev/null @@ -1,60 +0,0 @@ -ARG BASEIMAGE=nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu18.04 - -FROM $BASEIMAGE - -ARG PYTHON_VERSION=3.9 -ARG INSTALL_DEPS_EXTRA_ARGS -ARG USE_CONDA=false - -ADD scripts /tmp/scripts -RUN /tmp/scripts/install_ubuntu.sh -p $PYTHON_VERSION && \ - /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS - -# If USE_CONDA is false, use root to install python dependencies. -RUN if [ "$USE_CONDA" = false ] ; \ - then /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d gpu $INSTALL_DEPS_EXTRA_ARGS ; \ - fi - -WORKDIR /root - -# Allow configure to pick up GDK and CuDNN where it expects it. -# (Note: $CUDNN_VERSION is defined by NVidia's base image) -RUN _CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. 
-f1-2) && \ - mkdir -p /usr/local/cudnn-$_CUDNN_VERSION/cuda/include && \ - ln -s /usr/include/cudnn.h /usr/local/cudnn-$_CUDNN_VERSION/cuda/include/cudnn.h && \ - mkdir -p /usr/local/cudnn-$_CUDNN_VERSION/cuda/lib64 && \ - ln -s /etc/alternatives/libcudnn_so /usr/local/cudnn-$_CUDNN_VERSION/cuda/lib64/libcudnn.so && \ - ln -s /usr/local/cudnn{-$_CUDNN_VERSION,} - -ENV LD_LIBRARY_PATH /usr/local/openblas/lib:$LD_LIBRARY_PATH - -ARG BUILD_USER=onnxruntimedev -ARG BUILD_UID=1000 -RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID -WORKDIR /home/$BUILD_USER -USER $BUILD_USER - -ARG MINICONDA_PREFIX=/home/$BUILD_USER/miniconda3 -RUN if [ "$USE_CONDA" = true ] ; \ - then MINICONDA=miniconda.sh && \ - wget --no-verbose https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh -O $MINICONDA && \ - chmod a+x $MINICONDA && \ - ./$MINICONDA -b -p $MINICONDA_PREFIX && \ - rm ./$MINICONDA && \ - $MINICONDA_PREFIX/bin/conda clean --yes --all && \ - $MINICONDA_PREFIX/bin/conda install -y python=$PYTHON_VERSION ; \ - fi - -ENV PATH /home/$BUILD_USER/miniconda3/bin:$PATH - -# If USE_CONDA is true, use onnxruntimedev user to install python dependencies -RUN if [ "$USE_CONDA" = true ] ; \ - then /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d gpu $INSTALL_DEPS_EXTRA_ARGS -c ; \ - fi - -WORKDIR /root -USER root -RUN rm -rf /tmp/scripts - -WORKDIR /home/$BUILD_USER -USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino index 5f525c1310412..643c0d66d01f5 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino @@ -1,7 +1,7 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:${UBUNTU_VERSION} -ARG OPENVINO_VERSION=2024.3.0 +ARG OPENVINO_VERSION=2024.5.0 ARG PYTHON_VERSION=3.10 ADD scripts /tmp/scripts @@ -12,16 +12,16 @@ RUN /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d EdgeDevice RUN apt update && apt install -y libnuma1 ocl-icd-libopencl1 && \ rm -rf /var/lib/apt/lists/* /tmp/scripts -ENV INTEL_OPENVINO_DIR /opt/intel/openvino_${OPENVINO_VERSION} -ENV LD_LIBRARY_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH -ENV OpenVINO_DIR $INTEL_OPENVINO_DIR/runtime/cmake -ENV IE_PLUGINS_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64 +ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_${OPENVINO_VERSION} +ENV LD_LIBRARY_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH +ENV OpenVINO_DIR=$INTEL_OPENVINO_DIR/runtime/cmake +ENV IE_PLUGINS_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64 ENV DEBIAN_FRONTEND=noninteractive RUN cd /opt && mkdir -p intel && cd intel && \ - wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.3/linux/l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64.tgz && \ - tar xzf l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64.tgz && \ - mv l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64 openvino_2024.3.0 && \ + wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.5/linux/l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && \ + tar xzf l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && rm -rf 
l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && \ + mv l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64 openvino_2024.5.0 && \ cd $INTEL_OPENVINO_DIR/install_dependencies && ./install_openvino_dependencies.sh -y WORKDIR /root diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index e8d8dc0a64feb..4f58dc89333ba 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -44,7 +44,7 @@ COPY ${TRT_BINS_DIR}/TensorRT-${TAR_TRT_VERSION}.Linux.x86_64-gnu.cuda-${TAR_CUD RUN tar -xzvf /TensorRT-${TAR_TRT_VERSION}.tar.gz RUN cd /TensorRT-${TAR_TRT_VERSION}/python &&\ - python3 -m pip install tensorrt*cp38*.whl + python3 -m pip install tensorrt*cp310*.whl RUN cp -r /TensorRT-${TAR_TRT_VERSION}/lib/* /usr/lib/x86_64-linux-gnu/ RUN cp /TensorRT-${TAR_TRT_VERSION}/include/* /usr/local/include/ @@ -92,7 +92,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index ca00050121d67..246ef09f7be25 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -2,15 +2,14 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -ARG BASEIMAGE=arm64v8/almalinux:8 -FROM $BASEIMAGE +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12_dotnet:20241120.3 ENV PATH=/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh deleted file mode 100755 index adb0464d6496a..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -set -e -x - -os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) - -echo "installing for CentOS version : $os_major_version" -dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh index 9c3017240f77f..70bb373efb23f 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh @@ -39,9 +39,6 @@ mkdir -p /tmp/src cd /tmp/src CPU_ARCH=$(uname -m) -echo "Installing cmake" -GetFile "https://github.com/Kitware/CMake/releases/download/v3.30.1/cmake-3.30.1-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" -tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 05f290566b567..43dd3badef387 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12:20241008.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12:20241120.3 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh index c81e57c60c9da..d0b58ed28b8c9 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh @@ -7,8 +7,6 @@ echo "installing for os major version : 
$os_major_version" dnf install -y glibc-langpack-\* yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget -# export PATH=/opt/python/cp38-cp38/bin:$PATH - echo "installing rapidjson for AzureEP" wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz tar zxvf v1.1.0.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh index f576b867da73b..81de2abf3ff87 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh @@ -1,14 +1,14 @@ #!/bin/bash set -e -x pushd . -PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12") +PYTHON_EXES=("/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12" "/opt/python/cp313-cp313/bin/python3.13" "/opt/python/cp313-cp313t/bin/python3.13") CURRENT_DIR=$(pwd) if ! [ -x "$(command -v protoc)" ]; then $CURRENT_DIR/install_protobuf.sh fi popd export ONNX_ML=1 -export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" +export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" for PYTHON_EXE in "${PYTHON_EXES[@]}" do diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt index a4d50882c7320..a0c9a4326aec3 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt @@ -1,5 +1,5 @@ numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' +numpy==2.1.2 ; python_version >= '3.9' mypy pytest setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index ef28dde67617f..fffe92d2583a2 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,15 +2,14 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -ARG BASEIMAGE=amd64/almalinux:8 -FROM $BASEIMAGE +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12_dotnet:20241120.3 -ENV PATH=/usr/lib/jvm/msopenjdk-11/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 + ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_centos.sh deleted file mode 100755 index 17b80150c8484..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_centos.sh +++ /dev/null @@ -1,9 +0,0 @@ -!/bin/bash -set -e -x -if [ ! -f /etc/yum.repos.d/microsoft-prod.repo ]; then - os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) - echo "installing for CentOS version : $os_major_version" - rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm -fi -dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran gcc-toolset-12-libasan-devel libasan.x86_64 -locale diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh index fbbf4cf71157c..be906bf21a4fb 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh @@ -37,10 +37,7 @@ function GetFile { mkdir -p /tmp/src cd /tmp/src - -echo "Installing cmake" -GetFile https://github.com/Kitware/CMake/releases/download/v3.30.1/cmake-3.30.1-linux-`uname -m`.tar.gz /tmp/src/cmake-3.30.1-linux-`uname -m`.tar.gz -tar -zxf /tmp/src/cmake-3.30.1-linux-`uname -m`.tar.gz --strip=1 -C /usr +CPU_ARCH=$(uname -m) echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile index f48f684d98e83..d386db7ab7bd8 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile @@ -2,10 +2,9 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20241008.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20241120.3 ARG TRT_VERSION -RUN rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm && dnf install -y msopenjdk-11 #Install TensorRT only if TRT_VERSION is not empty RUN if [ -n "$TRT_VERSION" ]; then \ echo "TRT_VERSION is $TRT_VERSION" && \ @@ -31,11 +30,11 @@ else \ echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ fi -ENV PATH /usr/lib/jvm/msopenjdk-11/bin:$PATH +ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:$PATH ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -ENV CUDAHOSTCXX /opt/rh/gcc-toolset-11/root/usr/bin/g++ +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 +ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-11/root/usr/bin/g++ ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/scripts/install_deps.sh index fbbf4cf71157c..353498f71dfe0 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/scripts/install_deps.sh @@ -38,9 +38,6 @@ mkdir -p /tmp/src cd /tmp/src -echo "Installing cmake" -GetFile https://github.com/Kitware/CMake/releases/download/v3.30.1/cmake-3.30.1-linux-`uname -m`.tar.gz /tmp/src/cmake-3.30.1-linux-`uname -m`.tar.gz -tar -zxf /tmp/src/cmake-3.30.1-linux-`uname -m`.tar.gz --strip=1 -C /usr echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 69e24daf28785..ba6f28be4636c 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20241008.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20241120.3 ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty @@ -35,12 +35,12 @@ fi ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV CUDAHOSTCXX /opt/rh/gcc-toolset-12/root/usr/bin/g++ +ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-12/root/usr/bin/g++ ADD scripts /tmp/scripts RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/ubi.repo && \ - rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm && dnf install -y msopenjdk-11 && cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts -ENV PATH /usr/lib/jvm/msopenjdk-11/bin:$PATH -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 + cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:$PATH +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh index fbbf4cf71157c..8634146bf55e0 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh @@ -38,10 +38,6 @@ mkdir -p /tmp/src cd /tmp/src -echo "Installing cmake" -GetFile https://github.com/Kitware/CMake/releases/download/v3.30.1/cmake-3.30.1-linux-`uname -m`.tar.gz /tmp/src/cmake-3.30.1-linux-`uname -m`.tar.gz -tar -zxf /tmp/src/cmake-3.30.1-linux-`uname -m`.tar.gz --strip=1 -C /usr - echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz tar -zxf ninja-linux.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index 4242c29aa2c77..857fc445ef74a 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,7 +1,7 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241008.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241120.3 ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh index c81e57c60c9da..d0b58ed28b8c9 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh @@ -7,8 +7,6 @@ echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* yum install -y which redhat-lsb-core 
expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget -# export PATH=/opt/python/cp38-cp38/bin:$PATH - echo "installing rapidjson for AzureEP" wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz tar zxvf v1.1.0.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_deps.sh deleted file mode 100755 index f576b867da73b..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_deps.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash -set -e -x -pushd . -PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12") -CURRENT_DIR=$(pwd) -if ! [ -x "$(command -v protoc)" ]; then - $CURRENT_DIR/install_protobuf.sh -fi -popd -export ONNX_ML=1 -export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" - -for PYTHON_EXE in "${PYTHON_EXES[@]}" -do - ${PYTHON_EXE} -m pip install -r requirements.txt -done - -# No release binary for ccache aarch64, so we need to build it from source. -if ! [ -x "$(command -v ccache)" ]; then - ccache_url="https://github.com/ccache/ccache/archive/refs/tags/v4.8.tar.gz" - pushd . - curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o ccache_src.tar.gz $ccache_url - mkdir ccache_main - cd ccache_main - tar -zxf ../ccache_src.tar.gz --strip=1 - - mkdir build - cd build - cmake -DCMAKE_INSTALL_PREFIX=/usr/local _DCMAKE_BUILD_TYPE=Release .. - make - make install - which ccache - popd - rm -f ccache_src.tar.gz - rm -rf ccache_src -fi diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_protobuf.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_protobuf.sh deleted file mode 100755 index 31b5ca6f9e69b..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_protobuf.sh +++ /dev/null @@ -1,108 +0,0 @@ -#!/bin/bash -set -e -x - -INSTALL_PREFIX='/usr' -DEP_FILE_PATH='/tmp/scripts/deps.txt' -while getopts "p:d:" parameter_Option -do case "${parameter_Option}" -in -p) INSTALL_PREFIX=${OPTARG};; -d) DEP_FILE_PATH=${OPTARG};; -esac -done - - - -EXTRA_CMAKE_ARGS="-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_CXX_STANDARD=17" - -case "$(uname -s)" in - Darwin*) - echo 'Building ONNX Runtime on Mac OS X' - EXTRA_CMAKE_ARGS="$EXTRA_CMAKE_ARGS -DCMAKE_OSX_ARCHITECTURES=x86_64;arm64" - GCC_PATH=$(which clang) - GPLUSPLUS_PATH=$(which clang++) - ;; - Linux*) - SYS_LONG_BIT=$(getconf LONG_BIT) - DISTRIBUTOR=$(lsb_release -i -s) - - if [[ ("$DISTRIBUTOR" = "CentOS" || "$DISTRIBUTOR" = "RedHatEnterprise") && $SYS_LONG_BIT = "64" ]]; then - LIBDIR="lib64" - else - LIBDIR="lib" - fi - EXTRA_CMAKE_ARGS="$EXTRA_CMAKE_ARGS -DCMAKE_INSTALL_LIBDIR=$LIBDIR" - # Depending on how the compiler has been configured when it was built, sometimes "gcc -dumpversion" shows the full version. - GCC_VERSION=$(gcc -dumpversion | cut -d . -f 1) - #-fstack-clash-protection prevents attacks based on an overlapping heap and stack. 
- if [ "$GCC_VERSION" -ge 8 ]; then - CFLAGS="$CFLAGS -fstack-clash-protection" - CXXFLAGS="$CXXFLAGS -fstack-clash-protection" - fi - ARCH=$(uname -m) - GCC_PATH=$(which gcc) - GPLUSPLUS_PATH=$(which g++) - if [ "$ARCH" == "x86_64" ] && [ "$GCC_VERSION" -ge 9 ]; then - CFLAGS="$CFLAGS -fcf-protection" - CXXFLAGS="$CXXFLAGS -fcf-protection" - fi - export CFLAGS - export CXXFLAGS - ;; - *) - exit 1 -esac -mkdir -p "$INSTALL_PREFIX" - -if [ -x "$(command -v ninja)" ]; then - EXTRA_CMAKE_ARGS="$EXTRA_CMAKE_ARGS -G Ninja" -fi -echo "Installing abseil ..." -pushd . -absl_url=$(grep '^abseil_cpp' "$DEP_FILE_PATH" | cut -d ';' -f 2 ) -if [[ "$absl_url" = https* ]]; then - absl_url=$(echo $absl_url | sed 's/\.zip$/\.tar.gz/') - curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o absl_src.tar.gz $absl_url - mkdir abseil - cd abseil - tar -zxf ../absl_src.tar.gz --strip=1 -else - cp $absl_url absl_src.zip - unzip absl_src.zip - cd */ -fi - -CC=$GCC_PATH CXX=$GPLUSPLUS_PATH cmake "." "-DABSL_PROPAGATE_CXX_STD=ON" "-DCMAKE_BUILD_TYPE=Release" "-DBUILD_TESTING=OFF" "-DABSL_USE_EXTERNAL_GOOGLETEST=ON" "-DCMAKE_PREFIX_PATH=$INSTALL_PREFIX" "-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX" $EXTRA_CMAKE_ARGS -if [ -x "$(command -v ninja)" ]; then - ninja - ninja install -else - make -j$(getconf _NPROCESSORS_ONLN) - make install -fi -popd - -pushd . -echo "Installing protobuf ..." -protobuf_url=$(grep '^protobuf' $DEP_FILE_PATH | cut -d ';' -f 2 ) -if [[ "$protobuf_url" = https* ]]; then - protobuf_url=$(echo "$protobuf_url" | sed 's/\.zip$/\.tar.gz/') - curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o protobuf_src.tar.gz "$protobuf_url" - mkdir protobuf - cd protobuf - tar -zxf ../protobuf_src.tar.gz --strip=1 -else - cp $protobuf_url protobuf_src.zip - unzip protobuf_src.zip - cd protobuf-* -fi - -CC=$GCC_PATH CXX=$GPLUSPLUS_PATH cmake . 
"-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX" -DCMAKE_POSITION_INDEPENDENT_CODE=ON -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=Release -Dprotobuf_WITH_ZLIB_DEFAULT=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF "-DCMAKE_PREFIX_PATH=$INSTALL_PREFIX" $EXTRA_CMAKE_ARGS -Dprotobuf_ABSL_PROVIDER=package -if [ -x "$(command -v ninja)" ]; then - ninja - ninja install -else - make -j$(getconf _NPROCESSORS_ONLN) - make install -fi -popd diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt deleted file mode 100644 index 090bc94233a9f..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' -mypy -pytest -setuptools>=68.2.2 -wheel -onnx==1.17.0 -protobuf==4.21.12 -sympy==1.12 -flatbuffers -packaging>=22.0 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile index 85b1469a038fd..a69b98f86ba1b 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile @@ -5,7 +5,7 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 FROM $BASEIMAGE -ARG TRT_VERSION=10.4.0.26-1.cuda11.8 +ARG TRT_VERSION=10.6.0.26-1.cuda11.8 #Install TensorRT only if TRT_VERSION is not empty RUN if [ -n "${TRT_VERSION}" ]; then \ @@ -32,11 +32,11 @@ else \ echo "TRT_VERSION is x${TRT_VERSION} skipping Tensor RT Installation" ; \ fi -ENV PATH /usr/local/cuda/bin:$PATH -ENV CUDA_MODULE_LOADING "LAZY" +ENV PATH=/usr/local/cuda/bin:$PATH +ENV CUDA_MODULE_LOADING="LAZY" ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh index c81e57c60c9da..d0b58ed28b8c9 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh @@ -7,8 +7,6 @@ echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget -# export PATH=/opt/python/cp38-cp38/bin:$PATH - echo "installing rapidjson for AzureEP" wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz tar zxvf v1.1.0.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_deps.sh deleted file mode 100755 index f576b867da73b..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_deps.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash -set -e -x -pushd . 
-PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12") -CURRENT_DIR=$(pwd) -if ! [ -x "$(command -v protoc)" ]; then - $CURRENT_DIR/install_protobuf.sh -fi -popd -export ONNX_ML=1 -export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" - -for PYTHON_EXE in "${PYTHON_EXES[@]}" -do - ${PYTHON_EXE} -m pip install -r requirements.txt -done - -# No release binary for ccache aarch64, so we need to build it from source. -if ! [ -x "$(command -v ccache)" ]; then - ccache_url="https://github.com/ccache/ccache/archive/refs/tags/v4.8.tar.gz" - pushd . - curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o ccache_src.tar.gz $ccache_url - mkdir ccache_main - cd ccache_main - tar -zxf ../ccache_src.tar.gz --strip=1 - - mkdir build - cd build - cmake -DCMAKE_INSTALL_PREFIX=/usr/local _DCMAKE_BUILD_TYPE=Release .. - make - make install - which ccache - popd - rm -f ccache_src.tar.gz - rm -rf ccache_src -fi diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_protobuf.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_protobuf.sh deleted file mode 100755 index 31b5ca6f9e69b..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_protobuf.sh +++ /dev/null @@ -1,108 +0,0 @@ -#!/bin/bash -set -e -x - -INSTALL_PREFIX='/usr' -DEP_FILE_PATH='/tmp/scripts/deps.txt' -while getopts "p:d:" parameter_Option -do case "${parameter_Option}" -in -p) INSTALL_PREFIX=${OPTARG};; -d) DEP_FILE_PATH=${OPTARG};; -esac -done - - - -EXTRA_CMAKE_ARGS="-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_CXX_STANDARD=17" - -case "$(uname -s)" in - Darwin*) - echo 'Building ONNX Runtime on Mac OS X' - EXTRA_CMAKE_ARGS="$EXTRA_CMAKE_ARGS -DCMAKE_OSX_ARCHITECTURES=x86_64;arm64" - GCC_PATH=$(which clang) - GPLUSPLUS_PATH=$(which clang++) - ;; - Linux*) - SYS_LONG_BIT=$(getconf LONG_BIT) - DISTRIBUTOR=$(lsb_release -i -s) - - if [[ ("$DISTRIBUTOR" = "CentOS" || "$DISTRIBUTOR" = "RedHatEnterprise") && $SYS_LONG_BIT = "64" ]]; then - LIBDIR="lib64" - else - LIBDIR="lib" - fi - EXTRA_CMAKE_ARGS="$EXTRA_CMAKE_ARGS -DCMAKE_INSTALL_LIBDIR=$LIBDIR" - # Depending on how the compiler has been configured when it was built, sometimes "gcc -dumpversion" shows the full version. - GCC_VERSION=$(gcc -dumpversion | cut -d . -f 1) - #-fstack-clash-protection prevents attacks based on an overlapping heap and stack. - if [ "$GCC_VERSION" -ge 8 ]; then - CFLAGS="$CFLAGS -fstack-clash-protection" - CXXFLAGS="$CXXFLAGS -fstack-clash-protection" - fi - ARCH=$(uname -m) - GCC_PATH=$(which gcc) - GPLUSPLUS_PATH=$(which g++) - if [ "$ARCH" == "x86_64" ] && [ "$GCC_VERSION" -ge 9 ]; then - CFLAGS="$CFLAGS -fcf-protection" - CXXFLAGS="$CXXFLAGS -fcf-protection" - fi - export CFLAGS - export CXXFLAGS - ;; - *) - exit 1 -esac -mkdir -p "$INSTALL_PREFIX" - -if [ -x "$(command -v ninja)" ]; then - EXTRA_CMAKE_ARGS="$EXTRA_CMAKE_ARGS -G Ninja" -fi -echo "Installing abseil ..." -pushd . 
-absl_url=$(grep '^abseil_cpp' "$DEP_FILE_PATH" | cut -d ';' -f 2 ) -if [[ "$absl_url" = https* ]]; then - absl_url=$(echo $absl_url | sed 's/\.zip$/\.tar.gz/') - curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o absl_src.tar.gz $absl_url - mkdir abseil - cd abseil - tar -zxf ../absl_src.tar.gz --strip=1 -else - cp $absl_url absl_src.zip - unzip absl_src.zip - cd */ -fi - -CC=$GCC_PATH CXX=$GPLUSPLUS_PATH cmake "." "-DABSL_PROPAGATE_CXX_STD=ON" "-DCMAKE_BUILD_TYPE=Release" "-DBUILD_TESTING=OFF" "-DABSL_USE_EXTERNAL_GOOGLETEST=ON" "-DCMAKE_PREFIX_PATH=$INSTALL_PREFIX" "-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX" $EXTRA_CMAKE_ARGS -if [ -x "$(command -v ninja)" ]; then - ninja - ninja install -else - make -j$(getconf _NPROCESSORS_ONLN) - make install -fi -popd - -pushd . -echo "Installing protobuf ..." -protobuf_url=$(grep '^protobuf' $DEP_FILE_PATH | cut -d ';' -f 2 ) -if [[ "$protobuf_url" = https* ]]; then - protobuf_url=$(echo "$protobuf_url" | sed 's/\.zip$/\.tar.gz/') - curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o protobuf_src.tar.gz "$protobuf_url" - mkdir protobuf - cd protobuf - tar -zxf ../protobuf_src.tar.gz --strip=1 -else - cp $protobuf_url protobuf_src.zip - unzip protobuf_src.zip - cd protobuf-* -fi - -CC=$GCC_PATH CXX=$GPLUSPLUS_PATH cmake . "-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX" -DCMAKE_POSITION_INDEPENDENT_CODE=ON -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=Release -Dprotobuf_WITH_ZLIB_DEFAULT=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF "-DCMAKE_PREFIX_PATH=$INSTALL_PREFIX" $EXTRA_CMAKE_ARGS -Dprotobuf_ABSL_PROVIDER=package -if [ -x "$(command -v ninja)" ]; then - ninja - ninja install -else - make -j$(getconf _NPROCESSORS_ONLN) - make install -fi -popd diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile index 98ea5e119c319..51591e11ea2e9 100644 --- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.0 +ARG ROCM_VERSION=6.2.3 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' @@ -68,7 +68,7 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 # Create migraphx-ci environment ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/migraphx-ci ENV CONDA_DEFAULT_ENV migraphx-ci -RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 +RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.10 ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} # Enable migraphx-ci environment @@ -80,4 +80,4 @@ RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bi # Install migraphx RUN apt update && apt install -y migraphx -RUN pip install numpy packaging ml_dtypes==0.3.0 +RUN pip install numpy packaging ml_dtypes==0.5.0 diff --git a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile index 749e222aff499..f74c5c7b0295e 100644 --- a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG 
ROCM_VERSION=6.0 +ARG ROCM_VERSION=6.1.3 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' @@ -67,26 +67,30 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 # Create rocm-ci environment ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/rocm-ci ENV CONDA_DEFAULT_ENV rocm-ci -RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 +RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.10 ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} # Enable rocm-ci environment SHELL ["conda", "run", "-n", "rocm-ci", "/bin/bash", "-c"] -# ln -sf is needed to make sure that version `GLIBCXX_3.4.30' is found +# Some DLLs in the conda environment have conflict with the one installed in Ubuntu system. +# For example, the GCC version in the conda environment is 12.x, while the one in the Ubuntu 22.04 is 11.x. +# ln -sf to make sure we always use libstdc++.so.6 and libgcc_s.so.1 in the system. RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bin/../lib/libstdc++.so.6 +RUN ln -sf /usr/lib/x86_64-linux-gnu/libgcc_s.so.1 ${CONDA_ENVIRONMENT_PATH}/bin/../lib/libgcc_s.so.1 RUN pip install packaging \ - ml_dtypes==0.3.0 \ + ml_dtypes==0.5.0 \ pytest==7.4.4 \ pytest-xdist \ pytest-rerunfailures \ - scipy==1.10.0 \ - numpy==1.24.1 + scipy==1.14.1 \ + numpy==1.26.4 RUN apt install -y git # Install Cupy to decrease CPU utilization +# Note that the version of Cupy requires numpy < 1.27 RUN git clone https://github.com/ROCm/cupy && cd cupy && \ git checkout 432a8683351d681e00903640489cb2f4055d2e09 && \ export CUPY_INSTALL_USE_HIP=1 && \ diff --git a/tools/ci_build/github/linux/docker/scripts/install_java.sh b/tools/ci_build/github/linux/docker/scripts/install_java.sh index d11e29f693b8b..f4ea49963f115 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_java.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_java.sh @@ -5,7 +5,7 @@ if [ -f /etc/redhat-release ]; then dnf install -y java-11-openjdk-devel \ && dnf clean dbcache elif [ -f /etc/os-release ]; then - apt-get update && apt-get install -y openjdk-11-jdk + apt-get update && apt-get install -y openjdk-17-jdk else echo "Unsupported OS" exit 1 diff --git a/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh index 7f3160371aa24..87b9b960b7ebc 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh @@ -12,7 +12,6 @@ d) DEVICE_TYPE=${OPTARG};; v) echo "Cuda version is no longer accepted as an input to this script. Ignoring the input argument -v.";; t) echo "Installing python training dependencies argument is no longer accepted as an input to this script. Ignoring the input argument -t.";; m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; -u) echo "Installing ortmodule python dependencies argument is no longer accepted as an input to this script. Ignoring the input argument -u.";; r) echo "Installing ROCM python dependencies argument is no longer accepted as an input to this script. 
Ignoring the input argument -r.";; esac done diff --git a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh index 1ac1d226deec6..2d7acd1f701ff 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh @@ -3,7 +3,6 @@ set -e -x INSTALL_DEPS_TRAINING=false INSTALL_DEPS_DISTRIBUTED_SETUP=false -ORTMODULE_BUILD=false TARGET_ROCM=false CU_VER="11.8" TORCH_VERSION='2.0.0' @@ -18,7 +17,6 @@ d) DEVICE_TYPE=${OPTARG};; v) CU_VER=${OPTARG};; t) INSTALL_DEPS_TRAINING=true;; m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; -u) ORTMODULE_BUILD=true;; r) TARGET_ROCM=true;; c) USE_CONDA=true;; esac @@ -55,17 +53,3 @@ fi export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps\.sh/requirements\.txt} -if [ $DEVICE_TYPE = "gpu" ]; then - if [[ $INSTALL_DEPS_TRAINING = true ]]; then - if [[ $ORTMODULE_BUILD = false ]]; then - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/requirements.txt} - else - if [[ $TARGET_ROCM = false ]]; then - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/ortmodule\/stage1\/requirements_torch${TORCH_VERSION}_cu${CU_VER}\/requirements.txt} - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/ortmodule\/stage2\/requirements.txt} - else - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/ortmodule\/stage1\/requirements_rocm\/requirements.txt} - fi - fi - fi -fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index a980963429034..4bc609fc0badb 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -5,6 +5,7 @@ do case "${parameter_Option}" in p) PYTHON_VER=${OPTARG};; d) DEVICE_TYPE=${OPTARG};; +*) echo "Usage: $0 -p PYTHON_VER -d DEVICE_TYPE";; esac done @@ -20,54 +21,65 @@ apt-get update && apt-get install -y software-properties-common lsb-release OS_VERSION=$(lsb_release -r -s) -PACKAGE_LIST="autotools-dev \ - automake \ - build-essential \ - git apt-transport-https apt-utils \ - ca-certificates \ - pkg-config \ - wget \ - zlib1g \ - zlib1g-dev \ - libssl-dev \ - curl libcurl4-openssl-dev \ - autoconf \ - sudo \ - gfortran \ - python3-dev \ - language-pack-en \ - liblttng-ust-dev \ - libcurl4 \ - libkrb5-3 \ - libtinfo-dev \ - libtinfo5 \ - libtool \ - openssh-server \ - aria2 \ - bzip2 \ - unzip \ - zip \ - rsync libunwind8 libpng-dev libexpat1-dev \ - python3-setuptools python3-numpy python3-wheel python3-pip python3-pytest python3-distutils \ - openjdk-11-jdk \ - graphviz" - - -if [ $DEVICE_TYPE = "Normal" ]; then - PACKAGE_LIST="$PACKAGE_LIST libedit-dev libxml2-dev python3-packaging" +PACKAGE_LIST=( + "apt-transport-https" + "apt-utils" + "aria2" + "autoconf" + "automake" + "autotools-dev" + "build-essential" + "bzip2" + "ca-certificates" + "curl" + "gfortran" + "git" + "graphviz" + "language-pack-en" + "libcurl4" + "libcurl4-openssl-dev" + "libexpat1-dev" + "libkrb5-3" + "liblttng-ust-dev" + "libpng-dev" + "libssl-dev" + "libtinfo-dev" + "libtinfo5" + "libtool" + "libunwind8" + "openjdk-17-jdk" + "openssh-server" + "pkg-config" + "python3-dev" + "python3-distutils" + "python3-numpy" + "python3-pip" + "python3-pytest" + "python3-setuptools" + "python3-wheel" + "rsync" + 
"sudo" + "unzip" + "wget" + "zip" + "zlib1g" + "zlib1g-dev" +) +if [ "$DEVICE_TYPE" = "Normal" ]; then + PACKAGE_LIST+=("libedit-dev" "libxml2-dev" "python3-packaging") fi -PACKAGE_LIST="$PACKAGE_LIST libicu-dev" +PACKAGE_LIST+=("libicu-dev") -apt-get install -y --no-install-recommends $PACKAGE_LIST +apt-get install -y --no-install-recommends "${PACKAGE_LIST[@]}" locale-gen en_US.UTF-8 update-locale LANG=en_US.UTF-8 if [ "$OS_VERSION" = "20.04" ]; then # The defaul version of python is 3.8 - major=$(echo $PYTHON_VER | cut -d. -f1) - minor=$(echo $PYTHON_VER | cut -d. -f2) + major=$(echo "$PYTHON_VER" | cut -d. -f1) + minor=$(echo "$PYTHON_VER" | cut -d. -f2) if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 8 ]; then PYTHON_VER="3.8" fi @@ -75,19 +87,19 @@ if [ "$OS_VERSION" = "20.04" ]; then add-apt-repository -y ppa:deadsnakes/ppa apt-get update apt-get install -y --no-install-recommends \ - python${PYTHON_VER} \ - python${PYTHON_VER}-dev - update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VER} 1 + python"${PYTHON_VER}" \ + python"${PYTHON_VER}-"dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python"${PYTHON_VER}" 1 update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 2 - update-alternatives --set python3 /usr/bin/python${PYTHON_VER} + update-alternatives --set python3 /usr/bin/python"${PYTHON_VER}" #TODO: the old one(/usr/bin/pip3) should be uninstalled first. Because the one will be #put at /usr/local/. Then there will be two pips. - /usr/bin/python${PYTHON_VER} -m pip install --upgrade --force-reinstall pip==19.0.3 + /usr/bin/python"${PYTHON_VER}" -m pip install --upgrade --force-reinstall pip==19.0.3 fi elif [ "$OS_VERSION" = "22.04" ] ; then # The defaul version of python is 3.10 - major=$(echo $PYTHON_VER | cut -d. -f1) - minor=$(echo $PYTHON_VER | cut -d. -f2) + major=$(echo "$PYTHON_VER" | cut -d. -f1) + minor=$(echo "$PYTHON_VER" | cut -d. 
-f2) if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 10 ]; then PYTHON_VER="3.10" fi @@ -95,11 +107,11 @@ elif [ "$OS_VERSION" = "22.04" ] ; then add-apt-repository -y ppa:deadsnakes/ppa apt-get update apt-get install -y --no-install-recommends \ - python${PYTHON_VER} \ - python${PYTHON_VER}-dev - update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VER} 1 + python"${PYTHON_VER}" \ + python"${PYTHON_VER}"-dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python"${PYTHON_VER}" 1 update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2 - update-alternatives --set python3 /usr/bin/python${PYTHON_VER} + update-alternatives --set python3 /usr/bin/python"${PYTHON_VER}" fi else exit 1 diff --git a/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt b/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt index 6e8ee99d36300..99d0e9d10285b 100644 --- a/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt @@ -9,7 +9,7 @@ astunparse expecttest!=0.2.0 hypothesis numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' +numpy==2.1.2 ; python_version >= '3.9' psutil pyyaml requests diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh index dfda5ec73fdbe..a487bf7f91507 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh @@ -8,9 +8,6 @@ if [ "$os_major_version" -gt 7 ]; then PACKAGE_MANAGER="dnf" $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget fi -if [ ! -f /etc/yum.repos.d/microsoft-prod.repo ]; then - rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm -fi -# Install Java + # Install automatic documentation generation dependencies -$PACKAGE_MANAGER install -y msopenjdk-11 graphviz +$PACKAGE_MANAGER install -y graphviz diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh deleted file mode 100755 index 0d1b7049df7e1..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -set -e -x - -# Development tools and libraries -dnf -y install \ - graphviz - -if [ ! -d "/opt/conda/bin" ]; then - PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11") -else - PYTHON_EXES=("/opt/conda/bin/python") -fi - - -SYS_LONG_BIT=$(getconf LONG_BIT) -mkdir -p /tmp/src - -DISTRIBUTOR=$(lsb_release -i -s) - -if [[ ("$DISTRIBUTOR" = "CentOS" || "$DISTRIBUTOR" = "RedHatEnterprise") && $SYS_LONG_BIT = "64" ]]; then - LIBDIR="lib64" -else - LIBDIR="lib" -fi - -cd /tmp/src -source $(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)/install_shared_deps.sh - -cd /tmp/src - -if ! 
[ -x "$(command -v protoc)" ]; then - source ${0/%install_deps_aten\.sh/..\/install_protobuf.sh} -fi - -export ONNX_ML=1 -export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" - -for PYTHON_EXE in "${PYTHON_EXES[@]}" -do - ${PYTHON_EXE} -m pip install -r ${0/%install_deps_aten\.sh/requirements\.txt} - if ! [[ ${PYTHON_EXE} = "/opt/python/cp310-cp310/bin/python3.10" ]]; then - ${PYTHON_EXE} -m pip install -r ${0/%install_deps_aten\.sh/..\/training\/ortmodule\/stage1\/requirements_torch_cpu\/requirements.txt} - else - ${PYTHON_EXE} -m pip install torch==2.3.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - fi -done - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh deleted file mode 100755 index d8d2fbc06a00b..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -set -e -x - -# Development tools and libraries -yum -y install \ - graphviz - -if [ ! -d "/opt/conda/bin" ]; then - PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12") -else - PYTHON_EXES=("/opt/conda/bin/python") -fi - -os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) - -SYS_LONG_BIT=$(getconf LONG_BIT) -mkdir -p /tmp/src -GLIBC_VERSION=$(getconf GNU_LIBC_VERSION | cut -f 2 -d \.) - -DISTRIBUTOR=$(lsb_release -i -s) - -if [[ ("$DISTRIBUTOR" = "CentOS" || "$DISTRIBUTOR" = "RedHatEnterprise") && $SYS_LONG_BIT = "64" ]]; then - LIBDIR="lib64" -else - LIBDIR="lib" -fi - -cd /tmp/src -source $(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)/install_shared_deps.sh - -cd /tmp/src - -if ! [ -x "$(command -v protoc)" ]; then - source ${0/%install_deps_eager\.sh/..\/install_protobuf.sh} -fi - -export ONNX_ML=1 -export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" - -for PYTHON_EXE in "${PYTHON_EXES[@]}" -do - ${PYTHON_EXE} -m pip install -r ${0/%install_deps_eager\.sh/requirements\.txt} - ${PYTHON_EXE} -m pip install -r ${0/%install_deps_eager\.sh/..\/training\/ortmodule\/stage1\/torch_eager_cpu\/requirements.txt} -done - -cd /tmp/src -GetFile 'https://sourceware.org/pub/valgrind/valgrind-3.16.1.tar.bz2' /tmp/src/valgrind-3.16.1.tar.bz2 -tar -jxvf valgrind-3.16.1.tar.bz2 -cd valgrind-3.16.1 -./configure --prefix=/usr --libdir=/usr/lib64 --enable-only64bit --enable-tls -make -j$(getconf _NPROCESSORS_ONLN) -make install - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh deleted file mode 100755 index 06a117098d3fe..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -set -e -x - -# Development tools and libraries -dnf -y install \ - graphviz xz gcc-toolset-13-gcc-c++ gcc-toolset-13-gcc gcc-toolset-13-libstdc++-devel cmake python39-devel git -source /opt/rh/gcc-toolset-13/enable -mkdir -p /tmp/src - -cd /tmp/src -source $(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)/install_shared_deps.sh - -cd /tmp/src - -if ! 
[ -x "$(command -v protoc)" ]; then - source ${0/%install_deps_lort\.sh/..\/install_protobuf.sh} -fi - -export ONNX_ML=1 -export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" -PYTHON_EXE=/usr/bin/python3.9 - -echo "Installing Pytorch requirements" -# This may install PyTorch, which will be overrided by the PyTorch local build below. -# beartype is installed here so that onnxscript installation step won't -# install a version PyTorch doesn't like. Once beartype fixes this problem. -# We can remove this line. -$PYTHON_EXE -m pip install -r /tmp/scripts/lort/requirements.txt - -cd /usr/local/ -echo "Cloning ONNX Script" -git clone --recursive https://github.com/microsoft/onnxscript.git -cd onnxscript -$PYTHON_EXE -m pip install . -cd ~ && $PYTHON_EXE -c "import onnxscript; print(f'Installed ONNX Script: {onnxscript.__version__}')" - -cd /usr/local -echo "Cloning Pytorch" -git clone --recursive https://github.com/pytorch/pytorch.git -cd pytorch - -echo "Building and installing Pytorch" -VERBOSE=1 BUILD_LAZY_TS_BACKEND=1 $PYTHON_EXE setup.py install -cd ~ && $PYTHON_EXE -c "import torch; print(f'Installed Pytorch: {torch.__version__}')" - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh index 2f69435dc316e..69b0ea1321235 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh @@ -12,4 +12,4 @@ apt-get install -y gdb build-essential tar unzip make aria2 bzip2 # Install Java # Install automatic documentation generation dependencies apt-get update -apt-get install -y openjdk-11-jdk graphviz +apt-get install -y openjdk-17-jdk graphviz diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index 35e7a07b8bd8f..1261498679ea0 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -6,7 +6,8 @@ setuptools>=68.2.2 wheel onnx==1.17.0 protobuf==4.21.12 -sympy==1.12 +sympy==1.12 ; python_version < '3.9' +sympy==1.13 ; python_version >= '3.9' flatbuffers neural-compressor>=2.2.1 triton diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index af58426065f42..157e1658a09a4 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -1,6 +1,6 @@ cerberus numpy==1.24.4 ; python_version < '3.9' -numpy==2.0.0; python_version >= '3.9' +numpy==2.1.2; python_version >= '3.9' mypy pytest setuptools==69.0.3 diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh index 269337bbba042..0be64d96f3a34 100755 --- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh +++ b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh @@ -2,7 +2,7 @@ set -e -x # version -ROCM_VERSION=6.0 +ROCM_VERSION=6.2.3 while getopts "r:" parameter_Option do case "${parameter_Option}" diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt deleted 
file mode 100644 index 89bda11737d10..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt deleted file mode 100644 index b3b2651c8d26d..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ ---pre --f https://download.pytorch.org/whl/torch_stable.html -torch==2.0.0+cu118 -torchvision==0.15.1+cu118 -torchtext==0.15.1 -# TODO(bmeswani): packaging 22.0 removes support for LegacyVersion leading to errors because transformers 4.4.2 uses LegacyVersion -packaging==21.3 -setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt deleted file mode 100644 index 152a17db90366..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ ---pre --f https://download.pytorch.org/whl/torch_stable.html -torch==2.1.0+cu121 -torchvision==0.16.0+cu121 -torchtext==0.16.0 -packaging==23.1 -setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt deleted file mode 100644 index 846f8c15b257d..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ --f https://download.pytorch.org/whl/torch_stable.html -torch==2.3.0+cpu -setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt deleted file mode 100644 index 01fa7b0ff956e..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -scikit-learn -packaging==21.3 -transformers==v4.36.0 -accelerate==0.25.0 -wget diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt deleted file mode 100644 index 6858d99dc26a8..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ ---pre --f https://download.pytorch.org/whl/torch_stable.html -torch==2.2.0 -setuptools>=68.2.2 -cerberus -h5py -scikit-learn -numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' -pandas -parameterized diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt 
deleted file mode 100644 index 3b13a51f18e27..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -pandas -scikit-learn -numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' -transformers==v4.36.0 -accelerate==0.25.0 -rsa==4.9 -tensorboard==2.13.0 -h5py -wget -pytorch-lightning==2.3.3 -deepspeed==0.9.0 -fairscale==0.4.6 -parameterized>=0.8.1 -pydantic<2.0.0 diff --git a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh index 640028ee7678c..aef9793f696b6 100755 --- a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh +++ b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh @@ -5,7 +5,7 @@ set -e set -x -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH BUILD_DIR=${1:?"usage: $0 "} @@ -26,7 +26,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \ --build_wheel \ --skip_tests \ --enable_training_ops \ - --enable_pybind --cmake_extra_defines PYTHON_INCLUDE_DIR=/opt/python/cp38-cp38/include/python3.8 PYTHON_LIBRARY=/usr/lib64/librt.so \ + --enable_pybind --cmake_extra_defines PYTHON_INCLUDE_DIR=/opt/python/cp310-cp310/include/python3.10 PYTHON_LIBRARY=/usr/lib64/librt.so \ --use_nnapi \ --use_coreml diff --git a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh b/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh index 58d493086ece9..c857d3f1036bc 100755 --- a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh +++ b/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh @@ -7,7 +7,7 @@ set -e set -x -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH USAGE_TEXT="Usage: -b|--build-directory Specifies the build directory. Required. 
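The two ort_minimal build scripts above switch the manylinux interpreter from CPython 3.8 to 3.10 simply by prepending a different /opt/python/*/bin directory to PATH and pointing the pybind build at the matching headers. A minimal sketch of that selection mechanism, assuming the standard manylinux /opt/python layout — only the paths that appear in the diff above come from the patch, the rest is illustrative:

export PATH=/opt/python/cp310-cp310/bin:$PATH    # make the manylinux 3.10 interpreter the default python3
python3 --version                                # should now report Python 3.10.x
# The pybind build then needs headers that match the interpreter on PATH, hence the updated define:
#   --cmake_extra_defines PYTHON_INCLUDE_DIR=/opt/python/cp310-cp310/include/python3.10 PYTHON_LIBRARY=/usr/lib64/librt.so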
diff --git a/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py b/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py index df530a7c5e9ac..338e86b07e0d1 100644 --- a/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py +++ b/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py @@ -35,33 +35,6 @@ def main(): arch = config["arch"] build_params = config["build_params"] build_config = "MinSizeRel" # could make this configurable if needed - # Build and install protoc - protobuf_installation_script = ( - REPO_ROOT - / "tools" - / "ci_build" - / "github" - / "linux" - / "docker" - / "inference" - / "x86_64" - / "python" - / "cpu" - / "scripts" - / "install_protobuf.sh" - ) - subprocess.run( - [ - str(protobuf_installation_script), - "-p", - str(pathlib.Path(args.build_dir) / "installed"), - "-d", - str(REPO_ROOT / "cmake" / "deps.txt"), - ], - shell=False, - check=True, - ) - # build ORT build_command = ( [sys.executable, str(REPO_ROOT / "tools/ci_build/build.py"), *build_params] + (["--cmake_extra_defines", "ADD_DEBUG_INFO_TO_MINIMAL_BUILD=ON"] if args.with_debug_info else []) diff --git a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py index 2264742079d15..7fc7598f25e49 100644 --- a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py +++ b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py @@ -34,8 +34,13 @@ def get_section_sizes(binary_path, readelf_path, dump_to_file=None): for match in re.finditer(r"\[[\s\d]+\] (\..*)$", output, re.MULTILINE): items = match.group(1).split() name = items[0] + if name == ".relro_padding": + # padding fluctuates and isn't due to the actual code. as it adds noise to the diff exclude it + continue + # convert size from hex to int size = int(items[4], 16) + section_sizes[name] = size if dump_to_file: diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/requirements.txt b/tools/ci_build/github/linux/python/requirements.txt similarity index 54% rename from tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/requirements.txt rename to tools/ci_build/github/linux/python/requirements.txt index a4d50882c7320..200b9c2e50288 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/requirements.txt +++ b/tools/ci_build/github/linux/python/requirements.txt @@ -1,5 +1,4 @@ -numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' +numpy==2.1.2 mypy pytest setuptools>=68.2.2 @@ -8,3 +7,4 @@ onnx==1.17.0 protobuf==4.21.12 sympy==1.12 flatbuffers +psutil diff --git a/tools/ci_build/github/linux/run_dockerbuild.sh b/tools/ci_build/github/linux/run_dockerbuild.sh index 9944861f519f4..6618810c77f6d 100755 --- a/tools/ci_build/github/linux/run_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_dockerbuild.sh @@ -15,10 +15,6 @@ BUILD_DIR=$BUILD_BINARIESDIRECTORY YOCTO_VERSION="4.19" #Training only INSTALL_DEPS_DISTRIBUTED_SETUP=false -#Training only -ORTMODULE_BUILD=false -#Training only -USE_CONDA=false ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV="ALLOW_RELEASED_ONNX_OPSET_ONLY="$ALLOW_RELEASED_ONNX_OPSET_ONLY echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as $ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV" @@ -44,10 +40,6 @@ t) EXTRA_IMAGE_TAG=${OPTARG};; i) IMAGE_CACHE_CONTAINER_REGISTRY_NAME=${OPTARG};; # install distributed setup dependencies m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; -# install ortmodule specific 
dependencies -u) ORTMODULE_BUILD=true;; -# install and use conda -e) USE_CONDA=true;; *) echo "Invalid option";; esac done @@ -82,24 +74,6 @@ if [ $BUILD_OS = "yocto" ]; then $GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \ --docker-build-args="--build-arg TOOL_CHAIN=$TOOL_CHAIN_SCRIPT --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER}" \ --dockerfile $DOCKER_FILE --context . -elif [ $BUILD_DEVICE = "gpu" ]; then - # This code path is only for training. Inferecing pipeline uses CentOS - IMAGE="$BUILD_OS-gpu_training" - # Current build script doesn't support building shared lib with Python dependency. To enable building with PythonOp, - # We need to avoid `--no-undefined` when building shared lib (Otherwise, CIs will report `undefined symbols`), but removing that would bring some other concerns. - # Plus the fact training did not need build shared library, we disable the --build_shared_lib for training CIs. - NEED_BUILD_SHARED_LIB=false - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -t" - if [[ $INSTALL_DEPS_DISTRIBUTED_SETUP = true ]]; then - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -m" - fi - if [[ $ORTMODULE_BUILD = true ]]; then - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -u" - fi - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -v 11.8" - $GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \ - --docker-build-args="--build-arg BASEIMAGE=nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-${BUILD_OS} --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg INSTALL_DEPS_EXTRA_ARGS=\"${INSTALL_DEPS_EXTRA_ARGS}\" --build-arg USE_CONDA=${USE_CONDA} --network=host" \ - --dockerfile Dockerfile.ubuntu_gpu_training --context . 
elif [[ $BUILD_DEVICE = "openvino"* ]]; then BUILD_ARGS="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg OPENVINO_VERSION=${OPENVINO_VERSION} --build-arg UBUNTU_VERSION=${UBUNTU_VERSION}" IMAGE="$BUILD_OS-openvino" diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh index eb3a0132f8aba..2fec98e569919 100755 --- a/tools/ci_build/github/linux/run_python_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh @@ -2,14 +2,15 @@ set -e -x BUILD_CONFIG="Release" -while getopts "i:d:x:c:" parameter_Option +while getopts "i:d:x:c:p:" parameter_Option do case "${parameter_Option}" in i) DOCKER_IMAGE=${OPTARG};; d) DEVICE=${OPTARG};; x) BUILD_EXTR_PAR=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; -*) echo "Usage: $0 -i -d [-x ] [-c ]" +p) PYTHON_EXES=${OPTARG};; +*) echo "Usage: $0 -i -d [-x ] [-c ] [-p ]" exit 1;; esac done @@ -17,6 +18,10 @@ done mkdir -p "${HOME}/.onnx" DOCKER_SCRIPT_OPTIONS="-d ${DEVICE} -c ${BUILD_CONFIG}" +if [ "${PYTHON_EXES}" != "" ] ; then + DOCKER_SCRIPT_OPTIONS+=" -p ${PYTHON_EXES}" +fi + if [ "${BUILD_EXTR_PAR}" != "" ] ; then DOCKER_SCRIPT_OPTIONS+=" -x ${BUILD_EXTR_PAR}" fi diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh index e8f683efbb249..090d3e97f0d70 100755 --- a/tools/ci_build/github/linux/run_python_tests.sh +++ b/tools/ci_build/github/linux/run_python_tests.sh @@ -41,12 +41,12 @@ if [ $BUILD_DEVICE == "GPU" ]; then BUILD_ARGS="$BUILD_ARGS --use_cuda --use_tensorrt --cuda_version=$SHORT_CUDA_VERSION --tensorrt_home=/usr --cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION --cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" fi -# We assume the machine doesn't have gcc and python development header files, so we don't build onnxruntime from source + python3 -m pip install --upgrade pip # Install the packages that are needed for installing the onnxruntime python package python3 -m pip install -r /build/$BUILD_CONFIG/requirements.txt # Install the packages that are needed for running test scripts -python3 -m pip install pytest +python3 -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt # The "--no-index" flag is crucial. The local whl folder is just an additional source. 
Pypi's doc says "there is no # ordering in the locations that are searched" if we don't disable the default one with "--no-index" python3 -m pip install --no-index --find-links /build/whl $PYTHON_PACKAGE_NAME diff --git a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh index 9cd1222cabfa6..835f83e2b8bed 100755 --- a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh +++ b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh @@ -5,7 +5,7 @@ pip3 install --user --upgrade pip pip3 install --user numpy torch pytest pip3 install --user /build/Release/dist/*.whl -export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.8/site-packages:$PYTHONPATH +export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.10/site-packages:$PYTHONPATH python3 -m pytest -v /onnxruntime_src/tools/test/test_custom_ops_pytorch_exporter.py || exit 1 diff --git a/tools/ci_build/github/pai/pai_huggingface_bert_large_test.sh b/tools/ci_build/github/pai/pai_huggingface_bert_large_test.sh deleted file mode 100755 index fb4dbeb2e73d3..0000000000000 --- a/tools/ci_build/github/pai/pai_huggingface_bert_large_test.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -set -ex - -usage() { echo "Usage: $0 [-v ]" 1>&2; exit 1; } - -while getopts "v:" parameter_Option -do case "${parameter_Option}" -in -v) ROCM_VERSION=${OPTARG};; -*) usage ;; -esac -done - -MI200_DEVICE_NUMBERS=$(rocm-smi --showproductname | grep -c "MI250" | xargs) - -if [ "$MI200_DEVICE_NUMBERS" -gt "0" ]; then - RESULT_FILE=ci-mi200.huggingface.bert-large-rocm${ROCM_VERSION}.json -else - RESULT_FILE=ci-mi100.huggingface.bert-large-rocm${ROCM_VERSION}.json -fi - -python \ - /stage/huggingface-transformers/examples/pytorch/language-modeling/run_mlm.py \ - --model_name_or_path bert-large-uncased \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --do_train \ - --max_steps 260 \ - --logging_steps 20 \ - --output_dir ./test-mlm-bbu \ - --overwrite_output_dir \ - --per_device_train_batch_size 8 \ - --fp16 \ - --dataloader_num_workers 1 \ - --ort \ - --skip_memory_metrics - -cat ci-pipeline-actual.json - -python /onnxruntime_src/orttraining/tools/ci_test/compare_huggingface.py \ - ci-pipeline-actual.json \ - /onnxruntime_src/orttraining/tools/ci_test/results/"$RESULT_FILE" diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile deleted file mode 100644 index 4e55ce29f46ff..0000000000000 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ /dev/null @@ -1,143 +0,0 @@ -# Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete -FROM ubuntu:22.04 - -ARG ROCM_VERSION=6.1 -ARG AMDGPU_VERSION=${ROCM_VERSION} -ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' - -CMD ["/bin/bash"] - -RUN echo "$APT_PREF" > /etc/apt/preferences.d/rocm-pin-600 - -ENV DEBIAN_FRONTEND noninteractive - -RUN apt-get update && \ - apt-get install -y --no-install-recommends ca-certificates curl libnuma-dev gnupg && \ - curl -sL https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - &&\ - printf "deb [arch=amd64] https://repo.radeon.com/rocm/apt/$ROCM_VERSION/ jammy main" | tee /etc/apt/sources.list.d/rocm.list && \ - printf "deb [arch=amd64] https://repo.radeon.com/amdgpu/$AMDGPU_VERSION/ubuntu jammy main" | tee /etc/apt/sources.list.d/amdgpu.list && \ - apt-get update && apt-get install -y 
--no-install-recommends \ - sudo \ - libelf1 \ - kmod \ - file \ - python3 \ - python3-pip \ - rocm-dev \ - rocm-libs \ - build-essential && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -RUN groupadd -g 109 render - -RUN apt-get update -y && apt-get upgrade -y && apt-get autoremove -y libprotobuf\* protobuf-compiler\* && \ - rm -f /usr/local/bin/protoc && apt-get install -y locales unzip wget git && apt-get clean -y -RUN locale-gen en_US.UTF-8 -RUN update-locale LANG=en_US.UTF-8 -ENV LC_ALL C.UTF-8 -ENV LANG C.UTF-8 - -WORKDIR /stage - -# CMake -ENV CMAKE_VERSION=3.30.1 -RUN cd /usr/local && \ - wget -q -O - https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-Linux-x86_64.tar.gz | tar zxf - -ENV PATH=/usr/local/cmake-${CMAKE_VERSION}-linux-x86_64/bin:${PATH} - -# ccache -RUN mkdir -p /tmp/ccache && \ - cd /tmp/ccache && \ - wget -q -O - https://github.com/ccache/ccache/releases/download/v4.7.4/ccache-4.7.4-linux-x86_64.tar.xz | tar --strip 1 -J -xf - && \ - cp /tmp/ccache/ccache /usr/bin && \ - rm -rf /tmp/ccache - -# Install Conda -ENV PATH /opt/miniconda/bin:${PATH} -RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh --no-check-certificate && /bin/bash ~/miniconda.sh -b -p /opt/miniconda && \ - conda init bash && \ - conda config --set auto_activate_base false && \ - conda update --all && \ - rm ~/miniconda.sh && conda clean -ya - -# Create rocm-ci environment -ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/rocm-ci -ENV CONDA_DEFAULT_ENV rocm-ci -RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 -ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} - -# Enable rocm-ci environment -SHELL ["conda", "run", "-n", "rocm-ci", "/bin/bash", "-c"] - -# ln -sf is needed to make sure that version `GLIBCXX_3.4.30' is found -RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bin/../lib/libstdc++.so.6 - -# Install Pytorch -RUN export MAJOR=$(cut -d '.' -f 1 <<< "$ROCM_VERSION") && \ - export MINOR=$(cut -d '.' -f 2 <<< "$ROCM_VERSION") && \ - export PATCH=$(cut -d '.' -f 3 <<< "$ROCM_VERSION") && \ - pip install torch==2.1.2 torchvision==0.16.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ && \ - pip install torch-ort --no-dependencies - -##### Install Cupy to decrease CPU utilization -# Install non dev openmpi -RUN wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.bz2 && \ - tar -jxf openmpi-4.1.5.tar.bz2 && \ - cd openmpi-4.1.5 && \ - ./configure --prefix=/opt/ompi && \ - make -j4 all && \ - make install && \ - cd ../ && \ - rm -r openmpi-4.1.5 && \ - rm openmpi-4.1.5.tar.bz2 - -# Install CuPy, No stable version is available -RUN git clone https://github.com/ROCmSoftwarePlatform/cupy && cd cupy && \ - git checkout 432a8683351d681e00903640489cb2f4055d2e09 && \ - export CUPY_INSTALL_USE_HIP=1 && \ - export ROCM_HOME=/opt/rocm && \ - export HCC_AMDGPU_TARGET=gfx906,gfx908,gfx90a && \ - git submodule update --init && \ - pip install -e . --no-cache-dir -vvvv - -##### Install transformers to run tests -# rocm-ci branch contains instrumentation needed for loss curves and perf -RUN git clone https://github.com/microsoft/huggingface-transformers.git &&\ - cd huggingface-transformers &&\ - git checkout rocm-ci &&\ - pip install -e . 
- -RUN pip install \ - flatbuffers==2.0 \ - numpy==1.24.1 \ - onnx \ - cerberus \ - sympy \ - h5py \ - datasets==2.17.0 \ - requests \ - sacrebleu==1.5.1 \ - sacremoses \ - scipy==1.10.0 \ - scikit-learn \ - tokenizers \ - sentencepiece \ - wget \ - dill==0.3.4 \ - pytorch_lightning==2.3.3 \ - tensorboard \ - pytest-xdist \ - pytest-rerunfailures \ - ml_dtypes==0.3.0 \ - pytest==7.4.4 - -# Install migraphx -RUN apt update && apt install -y migraphx - -ENV ORTMODULE_ONNX_OPSET_VERSION=17 - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER diff --git a/tools/ci_build/github/windows/eager/requirements.txt b/tools/ci_build/github/windows/eager/requirements.txt index b285defd89f55..d118280c8e6d6 100644 --- a/tools/ci_build/github/windows/eager/requirements.txt +++ b/tools/ci_build/github/windows/eager/requirements.txt @@ -1,7 +1,7 @@ setuptools wheel numpy==1.21.6 ; python_version < '3.9' -numpy==2.0.0 ; python_version >= '3.9' +numpy==2.1.2 ; python_version >= '3.9' typing_extensions torch==2.2.0 parameterized diff --git a/tools/ci_build/github/windows/helpers.ps1 b/tools/ci_build/github/windows/helpers.ps1 index 95a36aa24e904..929418029e442 100644 --- a/tools/ci_build/github/windows/helpers.ps1 +++ b/tools/ci_build/github/windows/helpers.ps1 @@ -638,7 +638,7 @@ function Install-ONNX { $temp_dir = Get-TempDirectory $new_requirements_text_file = Join-Path $temp_dir "new_requirements.txt" Write-Host "Installing python packages..." - Get-Content "$src_root\tools\ci_build\github\linux\docker\inference\x86_64\python\cpu\scripts\requirements.txt" | Select-String -pattern 'onnx' -notmatch | Out-File $new_requirements_text_file + Get-Content "$src_root\tools\ci_build\github\linux\python\requirements.txt" | Select-String -pattern 'onnx' -notmatch | Out-File $new_requirements_text_file [string[]]$pip_args = "-m", "pip", "install", "-qq", "--disable-pip-version-check", "-r", $new_requirements_text_file &"python.exe" $pip_args diff --git a/tools/ci_build/github/windows/jar_esrp_dll.ps1 b/tools/ci_build/github/windows/jar_esrp_dll.ps1 new file mode 100644 index 0000000000000..8492d7591271b --- /dev/null +++ b/tools/ci_build/github/windows/jar_esrp_dll.ps1 @@ -0,0 +1,41 @@ +$instruction = $args[0] # extract or repack +$original_jar_file_directory = $args[1] # The directory where the original jar file is located +$original_jar_file_name = $args[2] # The name of the original jar file + +$original_jar_file_full_path = "$original_jar_file_directory\$original_jar_file_name" +$extracted_file_directory = "$original_jar_file_directory\jar_extracted_full_files" + +if ($instruction -eq "extract") { + Write-Host "Extracting the jar file $original_jar_file_full_path..." + & 7z x $original_jar_file_full_path -o"$extracted_file_directory" + if ($lastExitCode -ne 0) { + Write-Host -Object "7z extracting the jar file command failed. Exitcode: $exitCode" + exit $lastExitCode + } + Write-Host "Extracted files directory: $extracted_file_directory" + + Write-Host "Removing the original jar file..." + Remove-Item -Path "$original_jar_file_full_path" -Force + Write-Host "Removed the original jar file." +} +elseif ($instruction -eq "repack") { + Write-Host "Removing ESRP's CodeSignSummary file..." + # It is the summary generated by ESRP tool. It is not needed in the jar file. + Remove-Item -Path "$extracted_file_directory/CodeSignSummary*.*" -Force + Write-Host "Removed ESRP's CodeSignSummary file." 
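# Context for the step above, with an assumed invocation pattern (the pipeline stage that drives
# this script is not shown in this patch): based on the argument parsing at the top of the file
# ($args[0] = extract|repack, $args[1] = jar directory, $args[2] = jar file name), a signing job
# would presumably run something like:
#   .\jar_esrp_dll.ps1 extract $jar_dir onnxruntime-<version>.jar   # unpack and drop the original jar
#   # ...ESRP signs the DLLs under jar_extracted_full_files, emitting the CodeSignSummary*.* file...
#   .\jar_esrp_dll.ps1 repack  $jar_dir onnxruntime-<version>.jar   # re-create the signed jar
# $jar_dir and the jar file name here are hypothetical placeholders.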
+ + Write-Host "Repacking the jar file from directory $extracted_file_directory..." + & 7z a "$original_jar_file_full_path" "$extracted_file_directory\*" + if ($lastExitCode -ne 0) { + Write-Host -Object "7z repacking the jar file command failed. Exitcode: $exitCode" + exit $lastExitCode + } + Write-Host "Repacked the jar file $original_jar_file_full_path." + + Write-Host "Removing the extracted files..." + Remove-Item -Path "$extracted_file_directory" -Recurse -Force + Write-Host "Removed the extracted files." +} +else { + Write-Host "Invalid instruction: $instruction" +} diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat index 6a660ecaa40d2..34ddd75da16fc 100644 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ b/tools/ci_build/github/windows/setup_env_gpu.bat @@ -6,10 +6,10 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( ) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64;%PATH% ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.6.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% @REM The default version is still cuda v12.2, because set cuda v11.8 after it -set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8\lib +set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.6.0.26.Windows10.x86_64.cuda-11.8\lib if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64 ) else ( diff --git a/tools/ci_build/github/windows/setup_env_trt.bat b/tools/ci_build/github/windows/setup_env_trt.bat index 4f2272e306570..03734293be5c4 100644 --- a/tools/ci_build/github/windows/setup_env_trt.bat +++ b/tools/ci_build/github/windows/setup_env_trt.bat @@ -6,6 +6,6 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( ) else ( set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.6.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% set GRADLE_OPTS=-Dorg.gradle.daemon=false set CUDA_MODULE_LOADING=LAZY diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index 32c5ce7dd08d1..14aeff3df9c62 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -1,11 +1,12 @@ -# packages used by transformers python unittest (only enabled in Linux CPU CI Pipeline) +# packages used by transformers python unittest packaging -protobuf==3.20.2 -numpy==1.24.0 ; python_version < '3.12' -numpy==1.26.0 ; python_version >= '3.12' +# protobuf and numpy is same as tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +protobuf==4.21.12 +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' torch coloredlogs==15.0 -transformers==4.38.0 +transformers==4.46.3 parameterized>=0.8.1 psutil einops diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 0e9cd514d8aa5..b46d1e2559e46 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -30,14 +30,9 @@ "mac-ios-ci-pipeline.yml", 
"mac-ios-packaging-pipeline.yml", "mac-react-native-ci-pipeline.yml", - "orttraining-linux-ci-pipeline.yml", - "orttraining-linux-gpu-ci-pipeline.yml", - "orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml", - "orttraining-mac-ci-pipeline.yml", "win-ci-pipeline.yml", "win-gpu-dml-ci-pipeline.yml", "win-gpu-cuda-ci-pipeline.yml", - "win-gpu-training-ci-pipeline.yml", "win-gpu-doc-gen-ci-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", "win-gpu-webgpu-ci-pipeline.yml", diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 683d7b6be2aa8..11842f34ce45b 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -138,7 +138,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "tvm", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +182,8 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Intel.ML.OnnxRuntime" in package_name: + description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package description = ( "This package contains native shared library artifacts for all supported platforms of ONNX Runtime." 
@@ -225,7 +227,7 @@ def add_common_dependencies(xml_text, package_name, version): def generate_dependencies(xml_text, package_name, version): - dml_dependency = '' + dml_dependency = '' if package_name == "Microsoft.AI.MachineLearning": xml_text.append("") @@ -375,13 +377,11 @@ def generate_files(line_list, args): "mklml": "mklml.dll", "openmp": "libiomp5md.dll", "dnnl": "dnnl.dll", - "tvm": "tvm.dll", "providers_shared_lib": "onnxruntime_providers_shared.dll", "dnnl_ep_shared_lib": "onnxruntime_providers_dnnl.dll", "tensorrt_ep_shared_lib": "onnxruntime_providers_tensorrt.dll", "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", - "tvm_ep_shared_lib": "onnxruntime_providers_tvm.lib", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -394,7 +394,6 @@ def generate_files(line_list, args): "mklml_1": "libmklml_gnu.so", "openmp": "libiomp5.so", "dnnl": "libdnnl.so.1", - "tvm": "libtvm.so.0.5.1", "providers_shared_lib": "libonnxruntime_providers_shared.so", "dnnl_ep_shared_lib": "libonnxruntime_providers_dnnl.so", "tensorrt_ep_shared_lib": "libonnxruntime_providers_tensorrt.so", @@ -456,14 +455,6 @@ def generate_files(line_list, args): + '" target="build\\native\\include" />' ) - if args.execution_provider == "tvm": - files_list.append( - "' - ) - if args.execution_provider == "openvino": files_list.append( "' ) - if args.execution_provider == "tvm": - files_list.append( - "' - ) - files_list.append( - "' - ) - - tvm_build_path = os.path.join(args.ort_build_path, args.build_config, "_deps", "tvm-build") - if is_windows(): - files_list.append( - "' - ) - else: - # TODO(agladyshev): Add support for Linux. - raise RuntimeError("Now only Windows is supported for TVM EP.") - if args.execution_provider == "rocm" or is_rocm_gpu_package and not is_ado_packaging_build: files_list.append( "' ) + if is_windows(): + dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") + tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") + for dll_element in os.listdir(dll_list_path): + if dll_element.endswith("dll"): + files_list.append( + "' + ) + for tbb_element in os.listdir(tbb_list_path): + if tbb_element.endswith("dll"): + files_list.append( + "' + ) + if args.execution_provider == "cuda" or is_cuda_gpu_win_sub_package and not is_ado_packaging_build: files_list.append( "" ) - # Process tvm dependency - if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["tvm"])): - files_list.append( - "" - ) - # Some tools to be packaged in nightly debug build only, should not be released # These are copied to the runtimes folder for convenience of loading with the dlls # NOTE: nuget gives a spurious error on linux if these aren't in a separate directory to the library so diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 80420316c8bc3..228c8016170d9 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -35,20 +35,22 @@ def get_pipeline_names(): # mac "MacOS CI Pipeline", # training - "orttraining-amd-gpu-ci-pipeline", "orttraining-linux-ci-pipeline", "orttraining-linux-gpu-ci-pipeline", - "orttraining-ortmodule-distributed", # checks "onnxruntime-binary-size-checks-ci-pipeline", # big models "Big Models", # android "Linux Android Emulator QNN CI Pipeline", - # not currently required, but running ensures we're hitting all mobile platforms + # 
not currently required, but running these like internal PRs. "Android CI Pipeline", "iOS CI Pipeline", "ONNX Runtime React Native CI Pipeline", + "CoreML CI Pipeline", + "Linux DNNL CI Pipeline", + "Linux MIGraphX CI Pipeline", + "Linux ROCm CI Pipeline", ] return pipelines diff --git a/tools/python/upload_and_run_browserstack_tests.py b/tools/python/upload_and_run_browserstack_tests.py index 8751368e1b2fc..a4da87e4fe435 100644 --- a/tools/python/upload_and_run_browserstack_tests.py +++ b/tools/python/upload_and_run_browserstack_tests.py @@ -29,13 +29,16 @@ def upload_apk_parse_json(post_url, apk_path, id, token): return response_to_json(response) -def browserstack_build_request(devices, app_url, test_suite_url, test_platform, id, token): +def browserstack_build_request(devices, app_url, test_suite_url, test_platform, id, token, project, build_tag): headers = {} json_data = { "devices": devices, "app": app_url, "testSuite": test_suite_url, + "project": project, + "buildTag": build_tag, + "deviceLogs": True, } build_response = requests.post( @@ -78,22 +81,24 @@ def build_query_loop(build_id, test_platform, id, token): "--test_platform", type=str, help="Testing platform", choices=["espresso", "xcuitest"], required=True ) parser.add_argument( - "--app_apk_path", + "--app_path", type=Path, help=( - "Path to the app APK. " - "Typically, the app APK is in " + "Path to the app file. " + "For Android, typically, the app file (the APK) is in " "{build_output_dir}/android_test/android/app/build/outputs/apk/debug/app-debug.apk" + ". For iOS, you will have to build an IPA file from the test app, which is built from the .xcarchive path" ), required=True, ) parser.add_argument( - "--test_apk_path", + "--test_path", type=Path, help=( - "Path to the test APK. " + "Path to the test suite file. " "Typically, the test APK is in " "{build_output_dir}/android_test/android/app/build/outputs/apk/androidTest/debug/app-debug-androidTest.apk" + ". For iOS, you will have to create a .zip of the tests. After manually building the tests, the tests that you need to zip will be in {{Xcode DerivedData Folder Path}}/Build/Products" ), required=True, ) @@ -102,10 +107,17 @@ def build_query_loop(build_id, test_platform, id, token): type=str, nargs="+", help="List of devices to run the tests on. 
For more info, " - "see https://www.browserstack.com/docs/app-automate/espresso/specify-devices", + "see https://www.browserstack.com/docs/app-automate/espresso/specify-devices (Android) or https://www.browserstack.com/docs/app-automate/xcuitest/specify-devices (iOS)", required=True, ) + parser.add_argument( + "--project", + type=str, + help="Identifier to logically group multiple builds together", + default="ONNXRuntime tests", + ) + parser.add_argument("--build_tag", type=str, help="Identifier to tag the build with a unique name", default="") args = parser.parse_args() try: @@ -121,13 +133,13 @@ def build_query_loop(build_id, test_platform, id, token): # Upload the app and test suites upload_app_json = upload_apk_parse_json( f"https://api-cloud.browserstack.com/app-automate/{args.test_platform}/v2/app", - args.app_apk_path, + args.app_path, browserstack_id, browserstack_token, ) upload_test_json = upload_apk_parse_json( f"https://api-cloud.browserstack.com/app-automate/{args.test_platform}/v2/test-suite", - args.test_apk_path, + args.test_path, browserstack_id, browserstack_token, ) @@ -140,6 +152,8 @@ def build_query_loop(build_id, test_platform, id, token): args.test_platform, browserstack_id, browserstack_token, + args.project, + args.build_tag, ) # Get build status until the tests are no longer running diff --git a/tools/python/util/android/android.py b/tools/python/util/android/android.py index dd2dcce01bf4a..24004d6be761d 100644 --- a/tools/python/util/android/android.py +++ b/tools/python/util/android/android.py @@ -4,6 +4,7 @@ import collections import contextlib import datetime +import os import signal import subprocess import time @@ -105,8 +106,15 @@ def _stop_process_with_pid(pid: int): def start_emulator( - sdk_tool_paths: SdkToolPaths, avd_name: str, extra_args: typing.Optional[typing.Sequence[str]] = None + sdk_tool_paths: SdkToolPaths, + avd_name: str, + extra_args: typing.Optional[typing.Sequence[str]] = None, + timeout_minutes: int = 20, ) -> subprocess.Popen: + if check_emulator_running_using_avd_name(avd_name=avd_name): + raise RuntimeError( + f"An emulator with avd_name{avd_name} is already running. Please close it before starting a new one." + ) with contextlib.ExitStack() as emulator_stack, contextlib.ExitStack() as waiter_stack: emulator_args = [ sdk_tool_paths.emulator, @@ -122,6 +130,7 @@ def start_emulator( "-gpu", "guest", "-delay-adb", + "-verbose", ] # For Linux CIs we must use "-no-window" otherwise you'll get @@ -155,9 +164,9 @@ def start_emulator( waiter_stack.callback(_stop_process, waiter_process) # poll subprocesses. - # allow 20 minutes for startup as some CIs are slow. TODO: Make timeout configurable if needed. + # allow 20 minutes for startup as some CIs are slow. sleep_interval_seconds = 10 - end_time = datetime.datetime.now() + datetime.timedelta(minutes=20) + end_time = datetime.datetime.now() + datetime.timedelta(minutes=timeout_minutes) while True: waiter_ret, emulator_ret = waiter_process.poll(), emulator_process.poll() @@ -205,13 +214,127 @@ def start_emulator( _log.debug(f"sys.boot_completed='{getprop_value}'. 
Sleeping for {sleep_interval_seconds} before retrying.") time.sleep(sleep_interval_seconds) + # Verify if the emulator is now running + if not check_emulator_running_using_avd_name(avd_name=avd_name): + raise RuntimeError("Emulator failed to start.") return emulator_process -def stop_emulator(emulator_proc_or_pid: typing.Union[subprocess.Popen, int]): +def check_emulator_running_using_avd_name(avd_name: str) -> bool: + """ + Check if an emulator is running based on the provided AVD name. + :param avd_name: Name of the Android Virtual Device (AVD) to check. + :return: True if an emulator with the given AVD name is running, False otherwise. + """ + try: + # Step 1: List running devices + result = subprocess.check_output(["adb", "devices"], text=True).strip() + _log.info(f"adb devices output:\n{result}") + running_emulators = [line.split("\t")[0] for line in result.splitlines()[1:] if "emulator" in line] + + if not running_emulators: + _log.debug("No emulators running.") + return False # No emulators running + + # Step 2: Check each running emulator's AVD name + for emulator in running_emulators: + try: + avd_info = ( + subprocess.check_output(["adb", "-s", emulator, "emu", "avd", "name"], text=True) + .strip() + .split("\n")[0] + ) + _log.debug(f"AVD name for emulator {emulator}: {avd_info}") + if avd_info == avd_name: + return True + except subprocess.SubprocessError: + _log.warning(f"Error checking AVD name for emulator: {emulator}") + continue # Skip if there's an issue querying a specific emulator + + _log.warning(f"No emulator running with AVD name: {avd_name}") + return False # No matching AVD name found + except subprocess.SubprocessError as e: + _log.warning(f"Error checking emulator status: {e}") + return False + + +def check_emulator_running_using_process(emulator_proc: subprocess.Popen) -> bool: + """Check if the emulator process is running based on a Popen instance.""" + return emulator_proc.poll() is None + + +def check_emulator_running_using_pid(emulator_pid: int) -> bool: + """Check if the emulator process is running based on PID.""" + try: + os.kill(emulator_pid, 0) # Signal 0 checks process existence + return True + except OSError: + return False + + +def stop_emulator_by_proc(emulator_proc: subprocess.Popen, timeout_seconds: int = 120): + """ + Stops the emulator process using a subprocess.Popen instance. + :param emulator_proc: The emulator process as a subprocess.Popen instance. + :param timeout_seconds: Maximum time (in seconds) to wait for the emulator to stop. + """ + if not check_emulator_running_using_process(emulator_proc): + _log.warning("The specified emulator process is not running.") + return + + _log.info("Stopping emulator using subprocess.Popen instance.") + _stop_process(emulator_proc) + + # Wait for the process to stop + interval = 5 + end_time = datetime.datetime.now() + datetime.timedelta(seconds=timeout_seconds) + + while check_emulator_running_using_process(emulator_proc): + if datetime.datetime.now() > end_time: + raise RuntimeError(f"Failed to stop the emulator within the specified timeout = {timeout_seconds} seconds.") + _log.debug("Emulator still running. Checking again in 5 seconds...") + time.sleep(interval) + + _log.info("Emulator stopped successfully.") + + +def stop_emulator_by_pid(emulator_pid: int, timeout_seconds: int = 120): + """ + Stops the emulator process using a PID. + :param emulator_pid: The emulator process PID. + :param timeout_seconds: Maximum time (in seconds) to wait for the emulator to stop. 
+ """ + if not check_emulator_running_using_pid(emulator_pid): + _log.warning(f"No emulator process with PID {emulator_pid} is currently running.") + return + + _log.info(f"Stopping emulator with PID: {emulator_pid}") + _stop_process_with_pid(emulator_pid) + + # Wait for the process to stop + interval = 5 + end_time = datetime.datetime.now() + datetime.timedelta(seconds=timeout_seconds) + + while check_emulator_running_using_pid(emulator_pid): + if datetime.datetime.now() > end_time: + raise RuntimeError( + f"Failed to stop the emulator with PID {emulator_pid} within the specified timeout = {timeout_seconds} seconds." + ) + _log.debug("Emulator still running. Checking again in 5 seconds...") + time.sleep(interval) + + _log.info("Emulator stopped successfully.") + + +def stop_emulator(emulator_proc_or_pid: typing.Union[subprocess.Popen, int], timeout_seconds: int = 120): + """ + Stops the emulator process, checking its running status before and after stopping. + :param emulator_proc_or_pid: The emulator process (subprocess.Popen) or PID (int). + :param timeout_seconds: Maximum time (in seconds) to wait for the emulator to stop. + """ if isinstance(emulator_proc_or_pid, subprocess.Popen): - _stop_process(emulator_proc_or_pid) + stop_emulator_by_proc(emulator_proc_or_pid, timeout_seconds) elif isinstance(emulator_proc_or_pid, int): - _stop_process_with_pid(emulator_proc_or_pid) + stop_emulator_by_pid(emulator_proc_or_pid, timeout_seconds) else: raise ValueError("Expected either a PID or subprocess.Popen instance.") diff --git a/tools/scripts/python_test.sh b/tools/scripts/python_test.sh index 39d9ed432a1dc..53d350cf30611 100755 --- a/tools/scripts/python_test.sh +++ b/tools/scripts/python_test.sh @@ -7,15 +7,12 @@ export build_dir=$2 export config=$3 # it's for manylinux image -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH echo Install Python Deps cp $src_dir/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $build_dir/requirements.txt python3 -m pip install -r $build_dir/requirements.txt -mkdir -p $build_dir/requirements_torch_cpu/ -cp $src_dir/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $build_dir/requirements_torch_cpu/requirements.txt -python3 -m pip install -r $build_dir/requirements_torch_cpu/requirements.txt python3 -m pip list | grep onnx echo Install $config python package @@ -23,6 +20,5 @@ rm -rf $build_dir/$config/onnxruntime $build_dir/$config/pybind11 python3 -m pip install $build_dir/$config/dist/*.whl echo Run $config unit tests -pushd $build_dir/$config/ -python3 $src_dir/tools/ci_build/build.py --build_dir $build_dir --cmake_generator Ninja --config $config --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests --enable_transformers_tool_test --ctest_path "" -popd +cd $build_dir/$config/ +python3 $src_dir/tools/ci_build/build.py --build_dir $build_dir --cmake_generator Ninja --config $config --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests --enable_transformers_tool_test diff --git a/tools/scripts/symbolic_shape_infer_test.sh b/tools/scripts/symbolic_shape_infer_test.sh index d8d50c5e3fa91..6717c1d5a9f59 100755 --- a/tools/scripts/symbolic_shape_infer_test.sh +++ b/tools/scripts/symbolic_shape_infer_test.sh @@ -5,7 +5,7 @@ set -ex export build_dir=$1 # it's for manylinux image -export 
PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH echo Run symbolic shape infer test pushd $build_dir/Release/ diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index 1763290718a8f..f1272fc1b8626 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -1,8 +1,8 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. + +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "lib/Api/pch/pch.h" - #include "HardwareCoreEnumerator.h" namespace WINMLP { @@ -88,22 +88,33 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { #if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + bool isIntelSpecifiedPlatform = false; + const int kVendorID_IntelSpecifiedPlatformIDs[3] = { + // ExtendedModel,ExtendedFamily,Family Code, and Model Number + 0xa06a, // MTL + 0xc065, // ARL-H + 0xb065 // ARL-U + }; + int regs_leaf0[4]; - int regs_leaf7[4]; + int regs_leaf1[4]; __cpuid(regs_leaf0, 0); - __cpuid(regs_leaf7, 0x7); + __cpuid(regs_leaf1, 0x1); auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]); - auto isHybrid = (regs_leaf7[3] & (1 << 15)); + for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) { + if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) { + isIntelSpecifiedPlatform = true; + } + } - if (isIntel && isHybrid) { + if (isIntel && isIntelSpecifiedPlatform) { // We want to use the number of physical cores, but exclude cores without an LLC return cores.LLCCores; } #endif - return cores.PhysicalCores; }
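
Note on the winml/lib/Api/HardwareCoreEnumerator.cpp hunk above: the detection now reads CPUID leaf 1 instead of the leaf-7 hybrid bit. EAX of leaf 1 carries the stepping in its low 4 bits, with the model, family, and extended model/family fields above it, so shifting the register right by 4 discards the stepping and leaves a value that can be compared directly against the MTL/ARL identifiers listed in the patch. A minimal Python illustration of that arithmetic follows; the sample EAX value (including its stepping nibble) is hypothetical.

# Hypothetical CPUID leaf-1 EAX value for a Meteor Lake part; the low nibble (0x4) is the stepping.
eax_leaf1 = 0x000A06A4

# Platform signatures listed in the patch: MTL, ARL-H, ARL-U.
INTEL_SPECIFIED_PLATFORM_IDS = (0xA06A, 0xC065, 0xB065)

signature = eax_leaf1 >> 4                        # drop the 4-bit stepping field
print(hex(signature))                             # 0xa06a
print(signature in INTEL_SPECIFIED_PLATFORM_IDS)  # True -> treated as an Intel specified platform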
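
Note on tools/python/util/android/android.py: start_emulator now refuses to launch if an emulator with the same AVD name is already up, by listing adb devices and querying each emulator's AVD name. Below is a standalone sketch of that check, assuming adb is on PATH; the function name and the AVD name in the usage line are illustrative, not part of the patch.

import subprocess

def is_avd_running(avd_name: str) -> bool:
    """Return True if an already-running emulator reports the given AVD name."""
    # "adb devices" prints a header line followed by one "<serial>\t<state>" line per device.
    output = subprocess.check_output(["adb", "devices"], text=True).strip()
    emulators = [line.split("\t")[0] for line in output.splitlines()[1:] if "emulator" in line]

    for serial in emulators:
        # "adb -s <serial> emu avd name" prints the AVD name on its first line (followed by "OK").
        reply = subprocess.check_output(["adb", "-s", serial, "emu", "avd", "name"], text=True)
        lines = reply.strip().splitlines()
        if lines and lines[0] == avd_name:
            return True
    return False

if __name__ == "__main__":
    print(is_avd_running("ort_ci_avd"))  # hypothetical AVD name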
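
The same file splits stop_emulator into stop_emulator_by_proc and stop_emulator_by_pid, each of which polls until the process is gone or a timeout expires. A compact sketch of the PID variant's wait loop is below; the helper name is illustrative, and the signal-0 existence check mirrors the patch (a POSIX convention).

import datetime
import os
import time

def wait_for_pid_exit(pid: int, timeout_seconds: int = 120, interval_seconds: int = 5) -> None:
    """Block until the process with `pid` is gone, or raise once the timeout elapses."""
    deadline = datetime.datetime.now() + datetime.timedelta(seconds=timeout_seconds)
    while True:
        try:
            os.kill(pid, 0)  # signal 0 only tests for existence (as in the patch); it does not kill
        except OSError:
            return  # the process has exited
        if datetime.datetime.now() > deadline:
            raise RuntimeError(f"Process {pid} did not exit within {timeout_seconds} seconds.")
        time.sleep(interval_seconds)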
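
Note on tools/python/upload_and_run_browserstack_tests.py: the build request now carries project, buildTag, and deviceLogs alongside the uploaded app and test-suite URLs. A hedged sketch of the request shape follows; the /v2/build endpoint is inferred from the v2 app and test-suite upload URLs used in the script and should be verified against BrowserStack's App Automate documentation, and the device list, user, and token values are placeholders.

import requests

def start_browserstack_build(test_platform, app_url, test_suite_url, devices,
                             project, build_tag, user, token):
    # Endpoint inferred from the v2 upload URLs in the script ("espresso" or "xcuitest").
    url = f"https://api-cloud.browserstack.com/app-automate/{test_platform}/v2/build"
    payload = {
        "devices": devices,           # e.g. a list of device strings (see the --devices help text)
        "app": app_url,               # value returned by the app upload call
        "testSuite": test_suite_url,  # value returned by the test-suite upload call
        "project": project,           # groups related builds in the BrowserStack dashboard
        "buildTag": build_tag,        # free-form tag identifying this particular run
        "deviceLogs": True,           # capture device logs, as the patch now requests
    }
    response = requests.post(url, json=payload, auth=(user, token))
    response.raise_for_status()
    return response.json()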